Implement the node C# function (#2072)

1. Ports the `node` function from Python to C#.
2. Adds a missing authentication check.
3. Add validated string type `PoolName` for consistency with Python version.
This commit is contained in:
George Pollard
2022-06-23 13:44:14 +12:00
committed by GitHub
parent 1eeefce85c
commit 4eec0bfc45
30 changed files with 870 additions and 109 deletions

View File

@ -0,0 +1,179 @@
using System.Threading.Tasks;
using Microsoft.Azure.Functions.Worker;
using Microsoft.Azure.Functions.Worker.Http;
using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
namespace Microsoft.OneFuzz.Service;
public class NodeFunction {
private readonly ILogTracer _log;
private readonly IEndpointAuthorization _auth;
private readonly IOnefuzzContext _context;
public NodeFunction(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) {
_log = log;
_auth = auth;
_context = context;
}
private static readonly EntityConverter _entityConverter = new();
// [Function("Node")
public Async.Task<HttpResponseData> Run([HttpTrigger("GET", "PATCH", "POST", "DELETE")] HttpRequestData req) {
return _auth.CallIfUser(req, r => r.Method switch {
"GET" => Get(r),
"PATCH" => Patch(r),
"POST" => Post(r),
"DELETE" => Delete(r),
_ => throw new InvalidOperationException("Unsupported HTTP method"),
});
}
private async Async.Task<HttpResponseData> Get(HttpRequestData req) {
var request = await RequestHandling.ParseRequest<NodeSearch>(req);
if (!request.IsOk) {
return await _context.RequestHandling.NotOk(req, request.ErrorV, "pool get");
}
var search = request.OkV;
if (search.MachineId is Guid machineId) {
var node = await _context.NodeOperations.GetByMachineId(machineId);
if (node is null) {
return await _context.RequestHandling.NotOk(
req,
new Error(
Code: ErrorCode.UNABLE_TO_FIND,
Errors: new string[] { "unable to find node " }),
context: machineId.ToString());
}
var (tasks, messages) = await (
_context.NodeTasksOperations.GetByMachineId(machineId).ToListAsync().AsTask(),
_context.NodeMessageOperations.GetMessage(machineId).ToListAsync().AsTask());
var commands = messages.Select(m => m.Message).ToList();
return await RequestHandling.Ok(req, NodeToNodeSearchResult(node with { Tasks = tasks, Messages = commands }));
}
var nodes = await _context.NodeOperations.SearchStates(
states: search.State,
poolName: search.PoolName,
scaleSetId: search.ScalesetId).ToListAsync();
return await RequestHandling.Ok(req, nodes.Select(NodeToNodeSearchResult));
}
private static NodeSearchResult NodeToNodeSearchResult(Node node) {
return new NodeSearchResult(
PoolId: node.PoolId,
PoolName: node.PoolName,
MachineId: node.MachineId,
Version: node.Version,
Heartbeat: node.Heartbeat,
InitializedAt: node.InitializedAt,
State: node.State,
ScalesetId: node.ScalesetId,
ReimageRequested: node.ReimageRequested,
DeleteRequested: node.DeleteRequested,
DebugKeepNode: node.DebugKeepNode);
}
private async Async.Task<HttpResponseData> Patch(HttpRequestData req) {
var request = await RequestHandling.ParseRequest<NodeGet>(req);
if (!request.IsOk) {
return await _context.RequestHandling.NotOk(
req,
request.ErrorV,
"NodeReimage");
}
var authCheck = await _auth.CheckRequireAdmins(req);
if (!authCheck.IsOk) {
return await _context.RequestHandling.NotOk(req, authCheck.ErrorV, "NodeReimage");
}
var patch = request.OkV;
var node = await _context.NodeOperations.GetByMachineId(patch.MachineId);
if (node is null) {
return await _context.RequestHandling.NotOk(
req,
new Error(
Code: ErrorCode.UNABLE_TO_FIND,
Errors: new string[] { "unable to find node " }),
context: patch.MachineId.ToString());
}
await _context.NodeOperations.Stop(node, done: true);
if (node.DebugKeepNode) {
await _context.NodeOperations.Replace(node with { DebugKeepNode = false });
}
return await RequestHandling.Ok(req, true);
}
private async Async.Task<HttpResponseData> Post(HttpRequestData req) {
var request = await RequestHandling.ParseRequest<NodeUpdate>(req);
if (!request.IsOk) {
return await _context.RequestHandling.NotOk(
req,
request.ErrorV,
"NodeUpdate");
}
var authCheck = await _auth.CheckRequireAdmins(req);
if (!authCheck.IsOk) {
return await _context.RequestHandling.NotOk(req, authCheck.ErrorV, "NodeUpdate");
}
var post = request.OkV;
var node = await _context.NodeOperations.GetByMachineId(post.MachineId);
if (node is null) {
return await _context.RequestHandling.NotOk(
req,
new Error(
Code: ErrorCode.UNABLE_TO_FIND,
Errors: new string[] { "unable to find node " }),
context: post.MachineId.ToString());
}
if (post.DebugKeepNode is bool value) {
node = node with { DebugKeepNode = value };
}
await _context.NodeOperations.Replace(node);
return await RequestHandling.Ok(req, true);
}
private async Async.Task<HttpResponseData> Delete(HttpRequestData req) {
var request = await RequestHandling.ParseRequest<NodeGet>(req);
if (!request.IsOk) {
return await _context.RequestHandling.NotOk(
req,
request.ErrorV,
context: "NodeDelete");
}
var authCheck = await _auth.CheckRequireAdmins(req);
if (!authCheck.IsOk) {
return await _context.RequestHandling.NotOk(req, authCheck.ErrorV, "NodeDelete");
}
var delete = request.OkV;
var node = await _context.NodeOperations.GetByMachineId(delete.MachineId);
if (node is null) {
return await _context.RequestHandling.NotOk(
req,
new Error(
Code: ErrorCode.UNABLE_TO_FIND,
new string[] { "unable to find node" }),
context: delete.MachineId.ToString());
}
await _context.NodeOperations.SetHalt(node);
if (node.DebugKeepNode) {
await _context.NodeOperations.Replace(node with { DebugKeepNode = false });
}
return await RequestHandling.Ok(req, true);
}
}

View File

@ -1,7 +1,6 @@
using System.Text.Json; using System.Text.Json;
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
using Microsoft.OneFuzz.Service.OneFuzzLib.Orm; using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
using PoolName = System.String;
using Region = System.String; using Region = System.String;
namespace Microsoft.OneFuzz.Service; namespace Microsoft.OneFuzz.Service;

View File

@ -3,7 +3,6 @@ using System.Text.Json.Serialization;
using Microsoft.OneFuzz.Service.OneFuzzLib.Orm; using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
using Endpoint = System.String; using Endpoint = System.String;
using GroupId = System.Guid; using GroupId = System.Guid;
using PoolName = System.String;
using PrincipalId = System.Guid; using PrincipalId = System.Guid;
using Region = System.String; using Region = System.String;
@ -152,7 +151,6 @@ public record Error(ErrorCode Code, string[]? Errors = null);
public record UserInfo(Guid? ApplicationId, Guid? ObjectId, String? Upn); public record UserInfo(Guid? ApplicationId, Guid? ObjectId, String? Upn);
public record TaskDetails( public record TaskDetails(
TaskType Type, TaskType Type,
int Duration, int Duration,
@ -316,11 +314,11 @@ public record InstanceConfig
NetworkSecurityGroupConfig ProxyNsgConfig, NetworkSecurityGroupConfig ProxyNsgConfig,
AzureVmExtensionConfig? Extensions, AzureVmExtensionConfig? Extensions,
string ProxyVmSku, string ProxyVmSku,
IDictionary<Endpoint, ApiAccessRule>? ApiAccessRules, IDictionary<Endpoint, ApiAccessRule>? ApiAccessRules = null,
IDictionary<PrincipalId, GroupId[]>? GroupMembership, IDictionary<PrincipalId, GroupId[]>? GroupMembership = null,
IDictionary<string, string>? VmTags = null,
IDictionary<string, string>? VmTags, IDictionary<string, string>? VmssTags = null,
IDictionary<string, string>? VmssTags bool? RequireAdminPrivileges = null
) : EntityBase() { ) : EntityBase() {
public InstanceConfig(string instanceName) : this( public InstanceConfig(string instanceName) : this(
instanceName, instanceName,
@ -330,12 +328,7 @@ public record InstanceConfig
new NetworkConfig(), new NetworkConfig(),
new NetworkSecurityGroupConfig(), new NetworkSecurityGroupConfig(),
null, null,
"Standard_B2s", "Standard_B2s") { }
null,
null,
null,
null) { }
public InstanceConfig() : this(String.Empty) { } public InstanceConfig() : this(String.Empty) { }
public List<Guid>? CheckAdmins(List<Guid>? value) { public List<Guid>? CheckAdmins(List<Guid>? value) {
@ -346,7 +339,6 @@ public record InstanceConfig
} }
} }
//# At the moment, this only checks allowed_aad_tenants, however adding //# At the moment, this only checks allowed_aad_tenants, however adding
//# support for 3rd party JWT validation is anticipated in a future release. //# support for 3rd party JWT validation is anticipated in a future release.
public ResultVoid<List<string>> CheckInstanceConfig() { public ResultVoid<List<string>> CheckInstanceConfig() {

View File

@ -18,6 +18,22 @@ public record NodeCommandDelete(
string MessageId string MessageId
) : BaseRequest; ) : BaseRequest;
public record NodeGet(
Guid MachineId
) : BaseRequest;
public record NodeUpdate(
Guid MachineId,
bool? DebugKeepNode
) : BaseRequest;
public record NodeSearch(
Guid? MachineId = null,
List<NodeState>? State = null,
Guid? ScalesetId = null,
PoolName? PoolName = null
) : BaseRequest;
public record NodeStateEnvelope( public record NodeStateEnvelope(
NodeEventBase Event, NodeEventBase Event,
Guid MachineId Guid MachineId

View File

@ -4,7 +4,10 @@ using System.Text.Json.Serialization;
namespace Microsoft.OneFuzz.Service; namespace Microsoft.OneFuzz.Service;
[JsonConverter(typeof(BaseResponseConverter))] [JsonConverter(typeof(BaseResponseConverter))]
public abstract record BaseResponse(); public abstract record BaseResponse() {
public static implicit operator BaseResponse(bool value)
=> new BoolResult(value);
};
public record CanSchedule( public record CanSchedule(
bool Allowed, bool Allowed,
@ -15,6 +18,23 @@ public record PendingNodeCommand(
NodeCommandEnvelope? Envelope NodeCommandEnvelope? Envelope
) : BaseResponse(); ) : BaseResponse();
// TODO: not sure how much of this is actually
// needed in the search results, so at the moment
// it is a copy of the whole Node type
public record NodeSearchResult(
PoolName PoolName,
Guid MachineId,
Guid? PoolId,
string Version,
DateTimeOffset? Heartbeat,
DateTimeOffset? InitializedAt,
NodeState State,
Guid? ScalesetId,
bool ReimageRequested,
bool DeleteRequested,
bool DebugKeepNode
) : BaseResponse();
public record BoolResult( public record BoolResult(
bool Result bool Result
) : BaseResponse(); ) : BaseResponse();

View File

@ -54,6 +54,9 @@ namespace Microsoft.OneFuzz.Service {
public static OneFuzzResult<T_Ok> Error(ErrorCode errorCode, string error) => new(errorCode, new[] { error }); public static OneFuzzResult<T_Ok> Error(ErrorCode errorCode, string error) => new(errorCode, new[] { error });
public static OneFuzzResult<T_Ok> Error(Error err) => new(err); public static OneFuzzResult<T_Ok> Error(Error err) => new(err);
// Allow simple conversion of Errors to Results.
public static implicit operator OneFuzzResult<T_Ok>(Error err) => new(err);
} }
public struct OneFuzzResultVoid { public struct OneFuzzResultVoid {
@ -69,9 +72,12 @@ namespace Microsoft.OneFuzz.Service {
private OneFuzzResultVoid(Error err) => (ErrorV, IsOk) = (err, false); private OneFuzzResultVoid(Error err) => (ErrorV, IsOk) = (err, false);
public static OneFuzzResultVoid Ok() => new(); public static OneFuzzResultVoid Ok => new();
public static OneFuzzResultVoid Error(ErrorCode errorCode, string[] errors) => new(errorCode, errors); public static OneFuzzResultVoid Error(ErrorCode errorCode, string[] errors) => new(errorCode, errors);
public static OneFuzzResultVoid Error(ErrorCode errorCode, string error) => new(errorCode, new[] { error }); public static OneFuzzResultVoid Error(ErrorCode errorCode, string error) => new(errorCode, new[] { error });
public static OneFuzzResultVoid Error(Error err) => new(err); public static OneFuzzResultVoid Error(Error err) => new(err);
// Allow simple conversion of Errors to Results.
public static implicit operator OneFuzzResultVoid(Error err) => new(err);
} }
} }

View File

@ -0,0 +1,142 @@

using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Text.RegularExpressions;
namespace Microsoft.OneFuzz.Service;
static class Check {
private static readonly Regex _isAlnum = new(@"\A[a-zA-Z0-9]+\z", RegexOptions.Compiled);
public static bool IsAlnum(string input) => _isAlnum.IsMatch(input);
private static readonly Regex _isAlnumDash = new(@"\A[a-zA-Z0-9\-]+\z", RegexOptions.Compiled);
public static bool IsAlnumDash(string input) => _isAlnumDash.IsMatch(input);
}
// Base class for types that are wrappers around a validated string.
public abstract record ValidatedString(string String) {
public sealed override string ToString() => String;
}
// JSON converter for types that are wrappers around a validated string.
public abstract class ValidatedStringConverter<T> : JsonConverter<T> where T : ValidatedString {
protected abstract bool TryParse(string input, out T? output);
public sealed override bool CanConvert(Type typeToConvert)
=> typeToConvert == typeof(T);
public sealed override T? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) {
if (reader.TokenType != JsonTokenType.String) {
throw new JsonException("expected a string");
}
var value = reader.GetString();
if (value is null) {
throw new JsonException("expected a string");
}
if (TryParse(value, out var result)) {
return result;
} else {
throw new JsonException($"unable to parse input as a {typeof(T).Name}");
}
}
public sealed override void Write(Utf8JsonWriter writer, T value, JsonSerializerOptions options)
=> writer.WriteStringValue(value.String);
}
[JsonConverter(typeof(Converter))]
public record PoolName : ValidatedString {
private PoolName(string value) : base(value) {
Debug.Assert(Check.IsAlnumDash(value));
}
public static PoolName Parse(string input) {
if (TryParse(input, out var result)) {
return result;
}
throw new ArgumentException("Pool name must have only numbers, letters or dashes");
}
public static bool TryParse(string input, [NotNullWhen(returnValue: true)] out PoolName? result) {
if (!Check.IsAlnumDash(input)) {
result = default;
return false;
}
result = new PoolName(input);
return true;
}
public sealed class Converter : ValidatedStringConverter<PoolName> {
protected override bool TryParse(string input, out PoolName? output)
=> PoolName.TryParse(input, out output);
}
}
/* TODO: to be enabled in a separate PR
[JsonConverter(typeof(Converter))]
public record Region : ValidatedString {
private Region(string value) : base(value) {
Debug.Assert(Check.IsAlnum(value));
}
public static Region Parse(string input) {
if (TryParse(input, out var result)) {
return result;
}
throw new ArgumentException("Region name must have only numbers, letters or dashes");
}
public static bool TryParse(string input, [NotNullWhen(returnValue: true)] out Region? result) {
if (!Check.IsAlnum(input)) {
result = default;
return false;
}
result = new Region(input);
return true;
}
public sealed class Converter : ValidatedStringConverter<Region> {
protected override bool TryParse(string input, out Region? output)
=> Region.TryParse(input, out output);
}
}
[JsonConverter(typeof(Converter))]
public record Container : ValidatedString {
private Container(string value) : base(value) {
Debug.Assert(Check.IsAlnumDash(value));
}
public static Container Parse(string input) {
if (TryParse(input, out var result)) {
return result;
}
throw new ArgumentException("Container name must have only numbers, letters or dashes");
}
public static bool TryParse(string input, [NotNullWhen(returnValue: true)] out Container? result) {
if (!Check.IsAlnumDash(input)) {
result = default;
return false;
}
result = new Container(input);
return true;
}
public sealed class Converter : ValidatedStringConverter<Container> {
protected override bool TryParse(string input, out Container? output)
=> Container.TryParse(input, out output);
}
}
*/

View File

@ -165,7 +165,9 @@ namespace ApiService.TestHooks {
if (query.ContainsKey("states")) { if (query.ContainsKey("states")) {
states = query["states"].Split('-').Select(s => Enum.Parse<NodeState>(s)).ToList(); states = query["states"].Split('-').Select(s => Enum.Parse<NodeState>(s)).ToList();
} }
string? poolName = UriExtension.GetString("poolName", query); string? poolNameString = UriExtension.GetString("poolName", query);
PoolName? poolName = poolNameString is null ? null : PoolName.Parse(poolNameString);
var excludeUpdateScheduled = UriExtension.GetBool("excludeUpdateScheduled", query, false); var excludeUpdateScheduled = UriExtension.GetBool("excludeUpdateScheduled", query, false);
int? numResults = UriExtension.GetInt("numResults", query); int? numResults = UriExtension.GetInt("numResults", query);
@ -209,7 +211,7 @@ namespace ApiService.TestHooks {
var query = UriExtension.GetQueryComponents(req.Url); var query = UriExtension.GetQueryComponents(req.Url);
Guid poolId = Guid.Parse(query["poolId"]); Guid poolId = Guid.Parse(query["poolId"]);
string poolName = query["poolName"]; var poolName = PoolName.Parse(query["poolName"]);
Guid machineId = Guid.Parse(query["machineId"]); Guid machineId = Guid.Parse(query["machineId"]);
Guid? scaleSetId = default; Guid? scaleSetId = default;

View File

@ -25,7 +25,7 @@ namespace ApiService.TestHooks {
_log.Info("get pool"); _log.Info("get pool");
var query = UriExtension.GetQueryComponents(req.Url); var query = UriExtension.GetQueryComponents(req.Url);
var poolRes = await _poolOps.GetByName(query["name"]); var poolRes = await _poolOps.GetByName(PoolName.Parse(query["name"]));
if (poolRes.IsOk) { if (poolRes.IsOk) {
var resp = req.CreateResponse(HttpStatusCode.OK); var resp = req.CreateResponse(HttpStatusCode.OK);

View File

@ -58,7 +58,7 @@ public class UserCredentials : IUserCredentials {
return OneFuzzResult<string[]>.Ok(allowedAddTenantsQuery.ToArray()); return OneFuzzResult<string[]>.Ok(allowedAddTenantsQuery.ToArray());
} }
public async Task<OneFuzzResult<UserInfo>> ParseJwtToken(HttpRequestData req) { public virtual async Task<OneFuzzResult<UserInfo>> ParseJwtToken(HttpRequestData req) {
var authToken = GetAuthToken(req); var authToken = GetAuthToken(req);
if (authToken is null) { if (authToken is null) {
return OneFuzzResult<UserInfo>.Error(ErrorCode.INVALID_REQUEST, new[] { "unable to find authorization token" }); return OneFuzzResult<UserInfo>.Error(ErrorCode.INVALID_REQUEST, new[] { "unable to find authorization token" });

View File

@ -19,6 +19,8 @@ public interface IEndpointAuthorization {
Func<HttpRequestData, Async.Task<HttpResponseData>> method, Func<HttpRequestData, Async.Task<HttpResponseData>> method,
bool allowUser = false, bool allowUser = false,
bool allowAgent = false); bool allowAgent = false);
Async.Task<OneFuzzResultVoid> CheckRequireAdmins(HttpRequestData req);
} }
public class EndpointAuthorization : IEndpointAuthorization { public class EndpointAuthorization : IEndpointAuthorization {
@ -30,7 +32,7 @@ public class EndpointAuthorization : IEndpointAuthorization {
_log = log; _log = log;
} }
public async Async.Task<HttpResponseData> CallIf(HttpRequestData req, Func<HttpRequestData, Async.Task<HttpResponseData>> method, bool allowUser = false, bool allowAgent = false) { public virtual async Async.Task<HttpResponseData> CallIf(HttpRequestData req, Func<HttpRequestData, Async.Task<HttpResponseData>> method, bool allowUser = false, bool allowAgent = false) {
var tokenResult = await _context.UserCredentials.ParseJwtToken(req); var tokenResult = await _context.UserCredentials.ParseJwtToken(req);
if (!tokenResult.IsOk) { if (!tokenResult.IsOk) {
@ -77,6 +79,59 @@ public class EndpointAuthorization : IEndpointAuthorization {
); );
} }
public async Async.Task<OneFuzzResultVoid> CheckRequireAdmins(HttpRequestData req) {
var tokenResult = await _context.UserCredentials.ParseJwtToken(req);
if (!tokenResult.IsOk) {
return tokenResult.ErrorV;
}
var config = await _context.ConfigOperations.Fetch();
if (config is null) {
return new Error(
Code: ErrorCode.INVALID_CONFIGURATION,
Errors: new string[] { "no instance configuration found " });
}
return CheckRequireAdminsImpl(config, tokenResult.OkV);
}
private static OneFuzzResultVoid CheckRequireAdminsImpl(InstanceConfig config, UserInfo userInfo) {
// When there are no admins in the `admins` list, all users are considered
// admins. However, `require_admin_privileges` is still useful to protect from
// mistakes.
//
// To make changes while still protecting against accidental changes to
// pools, do the following:
//
// 1. set `require_admin_privileges` to `False`
// 2. make the change
// 3. set `require_admin_privileges` to `True`
if (config.RequireAdminPrivileges == false) {
return OneFuzzResultVoid.Ok;
}
if (config.Admins is null) {
return new Error(
Code: ErrorCode.UNAUTHORIZED,
Errors: new string[] { "pool modification disabled " });
}
if (userInfo.ObjectId is Guid objectId) {
if (config.Admins.Contains(objectId)) {
return OneFuzzResultVoid.Ok;
}
return new Error(
Code: ErrorCode.UNAUTHORIZED,
Errors: new string[] { "not authorized to manage pools" });
} else {
return new Error(
Code: ErrorCode.UNAUTHORIZED,
Errors: new string[] { "user had no Object ID" });
}
}
public OneFuzzResultVoid CheckAccess(HttpRequestData req) { public OneFuzzResultVoid CheckAccess(HttpRequestData req) {
throw new NotImplementedException(); throw new NotImplementedException();
} }

View File

@ -22,7 +22,7 @@ public interface INodeOperations : IStatefulOrm<Node, NodeState> {
IAsyncEnumerable<Node> SearchStates(Guid? poolId = default, IAsyncEnumerable<Node> SearchStates(Guid? poolId = default,
Guid? scaleSetId = default, Guid? scaleSetId = default,
IEnumerable<NodeState>? states = default, IEnumerable<NodeState>? states = default,
string? poolName = default, PoolName? poolName = default,
bool excludeUpdateScheduled = false, bool excludeUpdateScheduled = false,
int? numResults = default); int? numResults = default);
@ -32,7 +32,7 @@ public interface INodeOperations : IStatefulOrm<Node, NodeState> {
Async.Task<Node> Create( Async.Task<Node> Create(
Guid poolId, Guid poolId,
string poolName, PoolName poolName,
Guid machineId, Guid machineId,
Guid? scaleSetId, Guid? scaleSetId,
string version, string version,
@ -67,7 +67,7 @@ public class NodeOperations : StatefulOrm<Node, NodeState>, INodeOperations {
_logTracer.Info($"Setting scale-in protection on node {node.MachineId}"); _logTracer.Info($"Setting scale-in protection on node {node.MachineId}");
return await _context.VmssOperations.UpdateScaleInProtection((Guid)node.ScalesetId, node.MachineId, protectFromScaleIn: true); return await _context.VmssOperations.UpdateScaleInProtection((Guid)node.ScalesetId, node.MachineId, protectFromScaleIn: true);
} }
return OneFuzzResultVoid.Ok(); return OneFuzzResultVoid.Ok;
} }
public async Async.Task<bool> ScalesetNodeExists(Node node) { public async Async.Task<bool> ScalesetNodeExists(Node node) {
@ -207,7 +207,7 @@ public class NodeOperations : StatefulOrm<Node, NodeState>, INodeOperations {
public async Async.Task<Node> Create( public async Async.Task<Node> Create(
Guid poolId, Guid poolId,
string poolName, PoolName poolName,
Guid machineId, Guid machineId,
Guid? scaleSetId, Guid? scaleSetId,
string version, string version,
@ -308,7 +308,7 @@ public class NodeOperations : StatefulOrm<Node, NodeState>, INodeOperations {
Guid? poolId = default, Guid? poolId = default,
Guid? scaleSetId = default, Guid? scaleSetId = default,
IEnumerable<NodeState>? states = default, IEnumerable<NodeState>? states = default,
string? poolName = default, PoolName? poolName = default,
bool excludeUpdateScheduled = false, bool excludeUpdateScheduled = false,
int? numResults = default) { int? numResults = default) {
@ -318,6 +318,10 @@ public class NodeOperations : StatefulOrm<Node, NodeState>, INodeOperations {
queryParts.Add($"(pool_id eq '{poolId}')"); queryParts.Add($"(pool_id eq '{poolId}')");
} }
if (poolName is not null) {
queryParts.Add($"(PartitionKey eq '{poolName}')");
}
if (scaleSetId is not null) { if (scaleSetId is not null) {
queryParts.Add($"(scaleset_id eq '{scaleSetId}')"); queryParts.Add($"(scaleset_id eq '{scaleSetId}')");
} }
@ -346,7 +350,7 @@ public class NodeOperations : StatefulOrm<Node, NodeState>, INodeOperations {
Guid? poolId = default, Guid? poolId = default,
Guid? scaleSetId = default, Guid? scaleSetId = default,
IEnumerable<NodeState>? states = default, IEnumerable<NodeState>? states = default,
string? poolName = default, PoolName? poolName = default,
bool excludeUpdateScheduled = false, bool excludeUpdateScheduled = false,
int? numResults = default) { int? numResults = default) {
var query = NodeOperations.SearchStatesQuery(_context.ServiceConfiguration.OneFuzzVersion, poolId, scaleSetId, states, poolName, excludeUpdateScheduled, numResults); var query = NodeOperations.SearchStatesQuery(_context.ServiceConfiguration.OneFuzzVersion, poolId, scaleSetId, states, poolName, excludeUpdateScheduled, numResults);
@ -467,9 +471,8 @@ public class NodeMessageOperations : Orm<NodeMessage>, INodeMessageOperations {
_log = log; _log = log;
} }
public IAsyncEnumerable<NodeMessage> GetMessage(Guid machineId) { public IAsyncEnumerable<NodeMessage> GetMessage(Guid machineId)
return QueryAsync($"PartitionKey eq '{machineId}'"); => QueryAsync(Query.PartitionKey(machineId));
}
public async Async.Task ClearMessages(Guid machineId) { public async Async.Task ClearMessages(Guid machineId) {
_logTracer.Info($"clearing messages for node {machineId}"); _logTracer.Info($"clearing messages for node {machineId}");

View File

@ -48,7 +48,7 @@ namespace Microsoft.OneFuzz.Service {
public async Async.Task<OneFuzzResultVoid> DissociateNic(Nsg nsg, NetworkInterfaceResource nic) { public async Async.Task<OneFuzzResultVoid> DissociateNic(Nsg nsg, NetworkInterfaceResource nic) {
if (nic.Data.NetworkSecurityGroup == null) { if (nic.Data.NetworkSecurityGroup == null) {
return OneFuzzResultVoid.Ok(); return OneFuzzResultVoid.Ok;
} }
var azureNsg = await GetNsg(nsg.Name); var azureNsg = await GetNsg(nsg.Name);
@ -83,7 +83,7 @@ namespace Microsoft.OneFuzz.Service {
err, err,
) )
*/ */
return OneFuzzResultVoid.Ok(); return OneFuzzResultVoid.Ok;
} }
return OneFuzzResultVoid.Error( return OneFuzzResultVoid.Error(
ErrorCode.UNABLE_TO_UPDATE, ErrorCode.UNABLE_TO_UPDATE,
@ -93,7 +93,7 @@ namespace Microsoft.OneFuzz.Service {
); );
} }
return OneFuzzResultVoid.Ok(); return OneFuzzResultVoid.Ok;
} }
public async Async.Task<NetworkSecurityGroupResource?> GetNsg(string name) { public async Async.Task<NetworkSecurityGroupResource?> GetNsg(string name) {

View File

@ -4,7 +4,7 @@ using ApiService.OneFuzzLib.Orm;
namespace Microsoft.OneFuzz.Service; namespace Microsoft.OneFuzz.Service;
public interface IPoolOperations { public interface IPoolOperations {
public Async.Task<OneFuzzResult<Pool>> GetByName(string poolName); public Async.Task<OneFuzzResult<Pool>> GetByName(PoolName poolName);
Task<bool> ScheduleWorkset(Pool pool, WorkSet workSet); Task<bool> ScheduleWorkset(Pool pool, WorkSet workSet);
IAsyncEnumerable<Pool> GetByClientId(Guid clientId); IAsyncEnumerable<Pool> GetByClientId(Guid clientId);
} }
@ -16,8 +16,8 @@ public class PoolOperations : StatefulOrm<Pool, PoolState>, IPoolOperations {
} }
public async Async.Task<OneFuzzResult<Pool>> GetByName(string poolName) { public async Async.Task<OneFuzzResult<Pool>> GetByName(PoolName poolName) {
var pools = QueryAsync(filter: $"PartitionKey eq '{poolName}'"); var pools = QueryAsync(filter: $"PartitionKey eq '{poolName.String}'");
if (pools == null || await pools.CountAsync() == 0) { if (pools == null || await pools.CountAsync() == 0) {
return OneFuzzResult<Pool>.Error(ErrorCode.INVALID_REQUEST, "unable to find pool"); return OneFuzzResult<Pool>.Error(ErrorCode.INVALID_REQUEST, "unable to find pool");

View File

@ -6,7 +6,7 @@ namespace Microsoft.OneFuzz.Service;
public interface IScalesetOperations : IOrm<Scaleset> { public interface IScalesetOperations : IOrm<Scaleset> {
IAsyncEnumerable<Scaleset> Search(); IAsyncEnumerable<Scaleset> Search();
public IAsyncEnumerable<Scaleset?> SearchByPool(string poolName); public IAsyncEnumerable<Scaleset?> SearchByPool(PoolName poolName);
public Async.Task UpdateConfigs(Scaleset scaleSet); public Async.Task UpdateConfigs(Scaleset scaleSet);
@ -29,8 +29,8 @@ public class ScalesetOperations : StatefulOrm<Scaleset, ScalesetState>, IScalese
return QueryAsync(); return QueryAsync();
} }
public IAsyncEnumerable<Scaleset> SearchByPool(string poolName) { public IAsyncEnumerable<Scaleset> SearchByPool(PoolName poolName) {
return QueryAsync(filter: $"pool_name eq '{poolName}'"); return QueryAsync(filter: $"PartitionKey eq '{poolName}'");
} }

View File

@ -182,7 +182,7 @@ public class Scheduler : IScheduler {
return (bucketConfig, workUnit); return (bucketConfig, workUnit);
} }
record struct BucketId(Os os, Guid jobId, (string, string)? vm, string? pool, string setupContainer, bool? reboot, Guid? unique); record struct BucketId(Os os, Guid jobId, (string, string)? vm, PoolName? pool, string setupContainer, bool? reboot, Guid? unique);
private ILookup<BucketId, Task> BucketTasks(IEnumerable<Task> tasks) { private ILookup<BucketId, Task> BucketTasks(IEnumerable<Task> tasks) {
@ -205,7 +205,7 @@ public class Scheduler : IScheduler {
} }
// check for multiple VMs for 1.0.0 and later tasks // check for multiple VMs for 1.0.0 and later tasks
string? pool = task.Config.Pool?.PoolName; var pool = task.Config.Pool?.PoolName;
if ((task.Config.Pool?.Count ?? 0) > 1) { if ((task.Config.Pool?.Count ?? 0) > 1) {
unique = Guid.NewGuid(); unique = Guid.NewGuid();
} }
@ -219,5 +219,3 @@ public class Scheduler : IScheduler {
}); });
} }
} }

View File

@ -80,7 +80,7 @@ public class VmssOperations : IVmssOperations {
} }
var _ = await res.UpdateAsync(WaitUntil.Started, patch); var _ = await res.UpdateAsync(WaitUntil.Started, patch);
_log.Info($"VM extensions updated: {name}"); _log.Info($"VM extensions updated: {name}");
return OneFuzzResultVoid.Ok(); return OneFuzzResultVoid.Ok;
} else { } else {
return OneFuzzResultVoid.Error(canUpdate.ErrorV); return OneFuzzResultVoid.Error(canUpdate.ErrorV);
@ -169,13 +169,13 @@ public class VmssOperations : IVmssOperations {
_log.WithHttpStatus((r.GetRawResponse().Status, r.GetRawResponse().ReasonPhrase)).Error(msg); _log.WithHttpStatus((r.GetRawResponse().Status, r.GetRawResponse().ReasonPhrase)).Error(msg);
return OneFuzzResultVoid.Error(ErrorCode.UNABLE_TO_UPDATE, msg); return OneFuzzResultVoid.Error(ErrorCode.UNABLE_TO_UPDATE, msg);
} else { } else {
return OneFuzzResultVoid.Ok(); return OneFuzzResultVoid.Ok;
} }
} catch (Exception ex) when (ex is RequestFailedException || ex is CloudException) { } catch (Exception ex) when (ex is RequestFailedException || ex is CloudException) {
if (ex.Message.Contains(INSTANCE_NOT_FOUND) && protectFromScaleIn == false) { if (ex.Message.Contains(INSTANCE_NOT_FOUND) && protectFromScaleIn == false) {
_log.Info($"Tried to remove scale in protection on node {name} {vmId} but instance no longer exists"); _log.Info($"Tried to remove scale in protection on node {name} {vmId} but instance no longer exists");
return OneFuzzResultVoid.Ok(); return OneFuzzResultVoid.Ok;
} else { } else {
var msg = $"failed to update scale in protection on vm {vmId} for scaleset {name}"; var msg = $"failed to update scale in protection on vm {vmId} for scaleset {name}";
_log.Exception(ex, msg); _log.Exception(ex, msg);

View File

@ -143,10 +143,9 @@ public class EntityConverter {
}); });
} }
public string ToJsonString<T>(T typedEntity) { public string ToJsonString<T>(T typedEntity) => JsonSerializer.Serialize(typedEntity, _options);
var serialized = JsonSerializer.Serialize(typedEntity, _options);
return serialized; public T? FromJsonString<T>(string value) => JsonSerializer.Deserialize<T>(value, _options);
}
public TableEntity ToTableEntity<T>(T typedEntity) where T : EntityBase { public TableEntity ToTableEntity<T>(T typedEntity) where T : EntityBase {
if (typedEntity == null) { if (typedEntity == null) {
@ -211,8 +210,11 @@ public class EntityConverter {
return Guid.Parse(entity.GetString(ef.kind.ToString())); return Guid.Parse(entity.GetString(ef.kind.ToString()));
else if (ef.type == typeof(int)) else if (ef.type == typeof(int))
return int.Parse(entity.GetString(ef.kind.ToString())); return int.Parse(entity.GetString(ef.kind.ToString()));
else if (ef.type == typeof(PoolName))
// TODO: this should be able to be generic over any ValidatedString
return PoolName.Parse(entity.GetString(ef.kind.ToString()));
else { else {
throw new Exception("invalid "); throw new Exception($"invalid partition or row key type of {info.type} property {name}: {ef.type}");
} }
} }
@ -247,7 +249,6 @@ public class EntityConverter {
outputType = typeProvider.GetTypeInfo(v); outputType = typeProvider.GetTypeInfo(v);
} }
if (objType == typeof(string)) { if (objType == typeof(string)) {
var value = entity.GetString(fieldName); var value = entity.GetString(fieldName);
if (value.StartsWith('[') || value.StartsWith('{') || value == "null") { if (value.StartsWith('[') || value.StartsWith('{') || value == "null") {
@ -283,6 +284,3 @@ public class EntityConverter {
} }
} }

View File

@ -11,6 +11,9 @@ namespace ApiService.OneFuzzLib.Orm {
public static string PartitionKey(string partitionKey) public static string PartitionKey(string partitionKey)
=> TableClient.CreateQueryFilter($"PartitionKey eq {partitionKey}"); => TableClient.CreateQueryFilter($"PartitionKey eq {partitionKey}");
public static string PartitionKey(Guid partitionKey)
=> TableClient.CreateQueryFilter($"PartitionKey eq {partitionKey}");
public static string RowKey(string rowKey) public static string RowKey(string rowKey)
=> TableClient.CreateQueryFilter($"RowKey eq {rowKey}"); => TableClient.CreateQueryFilter($"RowKey eq {rowKey}");

View File

@ -24,6 +24,9 @@ public sealed class TestContext : IOnefuzzContext {
NodeTasksOperations = new NodeTasksOperations(logTracer, this); NodeTasksOperations = new NodeTasksOperations(logTracer, this);
TaskEventOperations = new TaskEventOperations(logTracer, this); TaskEventOperations = new TaskEventOperations(logTracer, this);
NodeMessageOperations = new NodeMessageOperations(logTracer, this); NodeMessageOperations = new NodeMessageOperations(logTracer, this);
ConfigOperations = new ConfigOperations(logTracer, this);
UserCredentials = new UserCredentials(logTracer, ConfigOperations);
} }
public TestEvents Events { get; set; } = new(); public TestEvents Events { get; set; } = new();
@ -36,6 +39,7 @@ public sealed class TestContext : IOnefuzzContext {
Node n => NodeOperations.Insert(n), Node n => NodeOperations.Insert(n),
Job j => JobOperations.Insert(j), Job j => JobOperations.Insert(j),
NodeTasks nt => NodeTasksOperations.Insert(nt), NodeTasks nt => NodeTasksOperations.Insert(nt),
InstanceConfig ic => ConfigOperations.Insert(ic),
_ => throw new NotImplementedException($"Need to add an TestContext.InsertAll case for {x.GetType()} entities"), _ => throw new NotImplementedException($"Need to add an TestContext.InsertAll case for {x.GetType()} entities"),
})); }));
@ -48,6 +52,7 @@ public sealed class TestContext : IOnefuzzContext {
public IStorage Storage { get; } public IStorage Storage { get; }
public ICreds Creds { get; } public ICreds Creds { get; }
public IContainers Containers { get; } public IContainers Containers { get; }
public IUserCredentials UserCredentials { get; set; }
public IRequestHandling RequestHandling { get; } public IRequestHandling RequestHandling { get; }
@ -57,12 +62,12 @@ public sealed class TestContext : IOnefuzzContext {
public INodeTasksOperations NodeTasksOperations { get; } public INodeTasksOperations NodeTasksOperations { get; }
public ITaskEventOperations TaskEventOperations { get; } public ITaskEventOperations TaskEventOperations { get; }
public INodeMessageOperations NodeMessageOperations { get; } public INodeMessageOperations NodeMessageOperations { get; }
public IConfigOperations ConfigOperations { get; }
// -- Remainder not implemented -- // -- Remainder not implemented --
public IConfig Config => throw new System.NotImplementedException(); public IConfig Config => throw new System.NotImplementedException();
public IConfigOperations ConfigOperations => throw new System.NotImplementedException();
public IDiskOperations DiskOperations => throw new System.NotImplementedException(); public IDiskOperations DiskOperations => throw new System.NotImplementedException();
@ -92,8 +97,6 @@ public sealed class TestContext : IOnefuzzContext {
public ISecretsOperations SecretsOperations => throw new System.NotImplementedException(); public ISecretsOperations SecretsOperations => throw new System.NotImplementedException();
public IUserCredentials UserCredentials => throw new System.NotImplementedException();
public IVmOperations VmOperations => throw new System.NotImplementedException(); public IVmOperations VmOperations => throw new System.NotImplementedException();
public IVmssOperations VmssOperations => throw new System.NotImplementedException(); public IVmssOperations VmssOperations => throw new System.NotImplementedException();

View File

@ -11,16 +11,16 @@ enum RequestType {
Agent, Agent,
} }
sealed class TestEndpointAuthorization : IEndpointAuthorization { sealed class TestEndpointAuthorization : EndpointAuthorization {
private readonly RequestType _type; private readonly RequestType _type;
private readonly IOnefuzzContext _context; private readonly IOnefuzzContext _context;
public TestEndpointAuthorization(RequestType type, IOnefuzzContext context) { public TestEndpointAuthorization(RequestType type, ILogTracer log, IOnefuzzContext context) : base(context, log) {
_type = type; _type = type;
_context = context; _context = context;
} }
public Task<HttpResponseData> CallIf( public override Task<HttpResponseData> CallIf(
HttpRequestData req, HttpRequestData req,
Func<HttpRequestData, Task<HttpResponseData>> method, Func<HttpRequestData, Task<HttpResponseData>> method,
bool allowUser = false, bool allowUser = false,

View File

@ -19,6 +19,8 @@ public sealed class TestServiceConfiguration : IServiceConfig {
public string? ApplicationInsightsInstrumentationKey { get; set; } = "TestAppInsightsInstrumentationKey"; public string? ApplicationInsightsInstrumentationKey { get; set; } = "TestAppInsightsInstrumentationKey";
public string? OneFuzzInstanceName => "UnitTestInstance";
// -- Remainder not implemented -- // -- Remainder not implemented --
public LogDestination[] LogDestinations { get => throw new System.NotImplementedException(); set => throw new System.NotImplementedException(); } public LogDestination[] LogDestinations { get => throw new System.NotImplementedException(); set => throw new System.NotImplementedException(); }
@ -42,8 +44,6 @@ public sealed class TestServiceConfiguration : IServiceConfig {
public string? OneFuzzInstance => throw new System.NotImplementedException(); public string? OneFuzzInstance => throw new System.NotImplementedException();
public string? OneFuzzInstanceName => throw new System.NotImplementedException();
public string? OneFuzzKeyvault => throw new System.NotImplementedException(); public string? OneFuzzKeyvault => throw new System.NotImplementedException();
public string? OneFuzzMonitor => throw new System.NotImplementedException(); public string? OneFuzzMonitor => throw new System.NotImplementedException();

View File

@ -0,0 +1,19 @@
using System.Threading.Tasks;
using Microsoft.Azure.Functions.Worker.Http;
using Microsoft.OneFuzz.Service;
using Async = System.Threading.Tasks;
namespace Tests.Fakes;
sealed class TestUserCredentials : UserCredentials {
private readonly OneFuzzResult<UserInfo> _tokenResult;
public TestUserCredentials(ILogTracer log, IConfigOperations instanceConfig, OneFuzzResult<UserInfo> tokenResult)
: base(log, instanceConfig) {
_tokenResult = tokenResult;
}
public override Task<OneFuzzResult<UserInfo>> ParseJwtToken(HttpRequestData req) => Async.Task.FromResult(_tokenResult);
}

View File

@ -29,7 +29,7 @@ public abstract class AgentEventsTestsBase : FunctionTestBase {
readonly Guid jobId = Guid.NewGuid(); readonly Guid jobId = Guid.NewGuid();
readonly Guid taskId = Guid.NewGuid(); readonly Guid taskId = Guid.NewGuid();
readonly Guid machineId = Guid.NewGuid(); readonly Guid machineId = Guid.NewGuid();
readonly string poolName = $"pool-{Guid.NewGuid()}"; readonly PoolName poolName = PoolName.Parse($"pool-{Guid.NewGuid()}");
readonly Guid poolId = Guid.NewGuid(); readonly Guid poolId = Guid.NewGuid();
readonly string poolVersion = $"version-{Guid.NewGuid()}"; readonly string poolVersion = $"version-{Guid.NewGuid()}";

View File

@ -27,7 +27,7 @@ public abstract class InfoTestBase : FunctionTestBase {
[Fact] [Fact]
public async Async.Task TestInfo_WithoutAuthorization_IsRejected() { public async Async.Task TestInfo_WithoutAuthorization_IsRejected() {
var auth = new TestEndpointAuthorization(RequestType.NoAuthorization, Context); var auth = new TestEndpointAuthorization(RequestType.NoAuthorization, Logger, Context);
var func = new Info(auth, Context); var func = new Info(auth, Context);
var result = await func.Run(TestHttpRequestData.Empty("GET")); var result = await func.Run(TestHttpRequestData.Empty("GET"));
@ -36,7 +36,7 @@ public abstract class InfoTestBase : FunctionTestBase {
[Fact] [Fact]
public async Async.Task TestInfo_WithAgentCredentials_IsRejected() { public async Async.Task TestInfo_WithAgentCredentials_IsRejected() {
var auth = new TestEndpointAuthorization(RequestType.Agent, Context); var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context);
var func = new Info(auth, Context); var func = new Info(auth, Context);
var result = await func.Run(TestHttpRequestData.Empty("GET")); var result = await func.Run(TestHttpRequestData.Empty("GET"));
@ -52,7 +52,7 @@ public abstract class InfoTestBase : FunctionTestBase {
await containerClient.CreateAsync(); await containerClient.CreateAsync();
await containerClient.GetBlobClient("instance_id").UploadAsync(new BinaryData(instanceId)); await containerClient.GetBlobClient("instance_id").UploadAsync(new BinaryData(instanceId));
var auth = new TestEndpointAuthorization(RequestType.User, Context); var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context);
var func = new Info(auth, Context); var func = new Info(auth, Context);
var result = await func.Run(TestHttpRequestData.Empty("GET")); var result = await func.Run(TestHttpRequestData.Empty("GET"));

View File

@ -0,0 +1,291 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using Microsoft.OneFuzz.Service;
using Tests.Fakes;
using Xunit;
using Xunit.Abstractions;
using Async = System.Threading.Tasks;
namespace Tests.Functions;
[Trait("Category", "Integration")]
public class AzureStorageNodeTest : NodeTestBase {
public AzureStorageNodeTest(ITestOutputHelper output)
: base(output, Integration.AzureStorage.FromEnvironment()) { }
}
public class AzuriteNodeTest : NodeTestBase {
public AzuriteNodeTest(ITestOutputHelper output)
: base(output, new Integration.AzuriteStorage()) { }
}
public abstract class NodeTestBase : FunctionTestBase {
public NodeTestBase(ITestOutputHelper output, IStorage storage)
: base(output, storage) { }
private readonly Guid _machineId = Guid.NewGuid();
private readonly Guid _scalesetId = Guid.NewGuid();
private readonly PoolName _poolName = PoolName.Parse($"pool-{Guid.NewGuid()}");
private readonly string _version = Guid.NewGuid().ToString();
[Fact]
public async Async.Task Search_SpecificNode_NotFound_ReturnsNotFound() {
var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context);
var req = new NodeSearch(MachineId: _machineId);
var func = new NodeFunction(Logger, auth, Context);
var result = await func.Run(TestHttpRequestData.FromJson("GET", req));
Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode);
}
[Fact]
public async Async.Task Search_SpecificNode_Found_ReturnsOk() {
await Context.InsertAll(
new Node(_poolName, _machineId, null, _version));
var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context);
var req = new NodeSearch(MachineId: _machineId);
var func = new NodeFunction(Logger, auth, Context);
var result = await func.Run(TestHttpRequestData.FromJson("GET", req));
Assert.Equal(HttpStatusCode.OK, result.StatusCode);
// make sure we got the data from the table
var deserialized = BodyAs<NodeSearchResult>(result);
Assert.Equal(_version, deserialized.Version);
}
[Fact]
public async Async.Task Search_MultipleNodes_CanFindNone() {
var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context);
var req = new NodeSearch();
var func = new NodeFunction(Logger, auth, Context);
var result = await func.Run(TestHttpRequestData.FromJson("GET", req));
Assert.Equal(HttpStatusCode.OK, result.StatusCode);
Assert.Equal(0, result.Body.Length);
}
[Fact]
public async Async.Task Search_MultipleNodes_ByPoolName() {
await Context.InsertAll(
new Node(PoolName.Parse("otherPool"), Guid.NewGuid(), null, _version),
new Node(_poolName, Guid.NewGuid(), null, _version),
new Node(_poolName, Guid.NewGuid(), null, _version));
var req = new NodeSearch(PoolName: _poolName);
var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context);
var func = new NodeFunction(Logger, auth, Context);
var result = await func.Run(TestHttpRequestData.FromJson("GET", req));
Assert.Equal(HttpStatusCode.OK, result.StatusCode);
// make sure we got the data from the table
var deserialized = BodyAs<NodeSearchResult[]>(result);
Assert.Equal(2, deserialized.Length);
}
[Fact]
public async Async.Task Search_MultipleNodes_ByScalesetId() {
await Context.InsertAll(
new Node(_poolName, Guid.NewGuid(), null, _version, ScalesetId: _scalesetId),
new Node(_poolName, Guid.NewGuid(), null, _version, ScalesetId: _scalesetId),
new Node(_poolName, Guid.NewGuid(), null, _version));
var req = new NodeSearch(ScalesetId: _scalesetId);
var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context);
var func = new NodeFunction(Logger, auth, Context);
var result = await func.Run(TestHttpRequestData.FromJson("GET", req));
Assert.Equal(HttpStatusCode.OK, result.StatusCode);
// make sure we got the data from the table
var deserialized = BodyAs<NodeSearchResult[]>(result);
Assert.Equal(2, deserialized.Length);
}
[Fact]
public async Async.Task Search_MultipleNodes_ByState() {
await Context.InsertAll(
new Node(_poolName, Guid.NewGuid(), null, _version, State: NodeState.Busy),
new Node(_poolName, Guid.NewGuid(), null, _version, State: NodeState.Busy),
new Node(_poolName, Guid.NewGuid(), null, _version));
var req = new NodeSearch(State: new List<NodeState> { NodeState.Busy });
var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context);
var func = new NodeFunction(Logger, auth, Context);
var result = await func.Run(TestHttpRequestData.FromJson("GET", req));
Assert.Equal(HttpStatusCode.OK, result.StatusCode);
// make sure we got the data from the table
var deserialized = BodyAs<NodeSearchResult[]>(result);
Assert.Equal(2, deserialized.Length);
}
[Fact]
public async Async.Task Search_MultipleNodes_ByMultipleStates() {
await Context.InsertAll(
new Node(_poolName, Guid.NewGuid(), null, _version, State: NodeState.Free),
new Node(_poolName, Guid.NewGuid(), null, _version, State: NodeState.Busy),
new Node(_poolName, Guid.NewGuid(), null, _version, State: NodeState.Busy),
new Node(_poolName, Guid.NewGuid(), null, _version));
var req = new NodeSearch(State: new List<NodeState> { NodeState.Free, NodeState.Busy });
var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context);
var func = new NodeFunction(Logger, auth, Context);
var result = await func.Run(TestHttpRequestData.FromJson("GET", req));
Assert.Equal(HttpStatusCode.OK, result.StatusCode);
// make sure we got the data from the table
var deserialized = BodyAs<NodeSearchResult[]>(result);
Assert.Equal(3, deserialized.Length);
}
[Theory]
[InlineData("PATCH")]
[InlineData("POST")]
[InlineData("DELETE")]
public async Async.Task RequiresAdmin(string method) {
// config must be found
await Context.InsertAll(
new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!));
// must be a user to auth
var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context);
// override the found user credentials
var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: Guid.NewGuid(), "upn");
Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult<UserInfo>.Ok(userInfo));
var req = new NodeGet(MachineId: _machineId);
var func = new NodeFunction(Logger, auth, Context);
var result = await func.Run(TestHttpRequestData.FromJson(method, req));
Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode);
var err = BodyAs<Error>(result);
Assert.Equal(ErrorCode.UNAUTHORIZED, err.Code);
Assert.Contains("pool modification disabled", err.Errors?.Single());
}
[Theory]
[InlineData("PATCH")]
[InlineData("POST")]
[InlineData("DELETE")]
public async Async.Task RequiresAdmin_CanBeDisabled(string method) {
// disable requiring admin privileges
await Context.InsertAll(
new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) {
RequireAdminPrivileges = false
});
// must be a user to auth
var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context);
// override the found user credentials
var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: Guid.NewGuid(), "upn");
Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult<UserInfo>.Ok(userInfo));
var req = new NodeGet(MachineId: _machineId);
var func = new NodeFunction(Logger, auth, Context);
var result = await func.Run(TestHttpRequestData.FromJson(method, req));
// we will fail with BadRequest but due to not being able to find the Node,
// not because of UNAUTHORIZED
Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode);
Assert.Equal(ErrorCode.UNABLE_TO_FIND, BodyAs<Error>(result).Code);
}
[Theory]
[InlineData("PATCH")]
[InlineData("POST")]
[InlineData("DELETE")]
public async Async.Task UserCanBeAdmin(string method) {
var userObjectId = Guid.NewGuid();
// config specifies that user is admin
await Context.InsertAll(
new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) {
Admins = new[] { userObjectId }
});
// must be a user to auth
var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context);
// override the found user credentials
var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: userObjectId, "upn");
Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult<UserInfo>.Ok(userInfo));
var req = new NodeGet(MachineId: _machineId);
var func = new NodeFunction(Logger, auth, Context);
var result = await func.Run(TestHttpRequestData.FromJson(method, req));
// we will fail with BadRequest but due to not being able to find the Node,
// not because of UNAUTHORIZED
Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode);
Assert.Equal(ErrorCode.UNABLE_TO_FIND, BodyAs<Error>(result).Code);
}
[Theory]
[InlineData("PATCH")]
[InlineData("POST")]
[InlineData("DELETE")]
public async Async.Task EnablingAdminForAnotherUserDoesNotPermitThisUser(string method) {
var userObjectId = Guid.NewGuid();
var otherObjectId = Guid.NewGuid();
// config specifies that a different user is admin
await Context.InsertAll(
new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) {
Admins = new[] { otherObjectId }
});
// must be a user to auth
var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context);
// override the found user credentials
var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: userObjectId, "upn");
Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult<UserInfo>.Ok(userInfo));
var req = new NodeGet(MachineId: _machineId);
var func = new NodeFunction(Logger, auth, Context);
var result = await func.Run(TestHttpRequestData.FromJson(method, req));
Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode);
var err = BodyAs<Error>(result);
Assert.Equal(ErrorCode.UNAUTHORIZED, err.Code);
Assert.Contains("not authorized to manage pools", err.Errors?.Single());
}
[Theory]
[InlineData("PATCH")]
[InlineData("POST")]
[InlineData("DELETE")]
public async Async.Task CanPerformOperation(string method) {
// disable requiring admin privileges
await Context.InsertAll(
new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) {
RequireAdminPrivileges = false
},
new Node(_poolName, _machineId, null, _version));
// must be a user to auth
var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context);
// override the found user credentials
var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: Guid.NewGuid(), "upn");
Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult<UserInfo>.Ok(userInfo));
// all of these operations use NodeGet
var req = new NodeGet(MachineId: _machineId);
var func = new NodeFunction(Logger, auth, Context);
var result = await func.Run(TestHttpRequestData.FromJson(method, req));
Assert.Equal(HttpStatusCode.OK, result.StatusCode);
}
}

View File

@ -7,6 +7,7 @@ using Azure.Storage;
using Azure.Storage.Blobs; using Azure.Storage.Blobs;
using Microsoft.Azure.Functions.Worker.Http; using Microsoft.Azure.Functions.Worker.Http;
using Microsoft.OneFuzz.Service; using Microsoft.OneFuzz.Service;
using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
using Tests.Fakes; using Tests.Fakes;
using Xunit.Abstractions; using Xunit.Abstractions;
@ -60,6 +61,9 @@ public abstract class FunctionTestBase : IDisposable {
return sr.ReadToEnd(); return sr.ReadToEnd();
} }
protected static T BodyAs<T>(HttpResponseData data)
=> new EntityConverter().FromJsonString<T>(BodyAsString(data)) ?? throw new Exception($"unable to deserialize body as {typeof(T)}");
public void Dispose() { public void Dispose() {
var (accountName, accountKey) = _storage.GetStorageAccountNameAndKey("").Result; // sync for test impls var (accountName, accountKey) = _storage.GetStorageAccountNameAndKey("").Result; // sync for test impls
if (accountName is not null && accountKey is not null) { if (accountName is not null && accountKey is not null) {
@ -72,6 +76,7 @@ public abstract class FunctionTestBase : IDisposable {
new StorageSharedKeyCredential(accountName, accountKey)); new StorageSharedKeyCredential(accountName, accountKey));
} }
} }
private void CleanupBlobs(Uri endpoint, StorageSharedKeyCredential creds) { private void CleanupBlobs(Uri endpoint, StorageSharedKeyCredential creds) {
var blobClient = new BlobServiceClient(endpoint, creds); var blobClient = new BlobServiceClient(endpoint, creds);

View File

@ -74,21 +74,26 @@ namespace Tests {
); );
} }
public static Gen<Node> Node() { public static Gen<PoolName> PoolNameGen { get; }
return Arb.Generate<Tuple<Tuple<DateTimeOffset?, string, Guid?, Guid, NodeState>, Tuple<Guid?, DateTimeOffset, string, bool, bool, bool>>>().Select( = from name in Arb.Generate<NonEmptyString>()
arg => new Node( where PoolName.TryParse(name.Get, out _)
select PoolName.Parse(name.Get);
public static Gen<Node> Node { get; }
= from arg in Arb.Generate<Tuple<Tuple<DateTimeOffset?, Guid?, Guid, NodeState>, Tuple<Guid?, DateTimeOffset, string, bool, bool, bool>>>()
from poolName in PoolNameGen
select new Node(
InitializedAt: arg.Item1.Item1, InitializedAt: arg.Item1.Item1,
PoolName: arg.Item1.Item2, PoolName: poolName,
PoolId: arg.Item1.Item3, PoolId: arg.Item1.Item3,
MachineId: arg.Item1.Item4, MachineId: arg.Item1.Item3,
State: arg.Item1.Item5, State: arg.Item1.Item4,
ScalesetId: arg.Item2.Item1, ScalesetId: arg.Item2.Item1,
Heartbeat: arg.Item2.Item2, Heartbeat: arg.Item2.Item2,
Version: arg.Item2.Item3, Version: arg.Item2.Item3,
ReimageRequested: arg.Item2.Item4, ReimageRequested: arg.Item2.Item4,
DeleteRequested: arg.Item2.Item5, DeleteRequested: arg.Item2.Item5,
DebugKeepNode: arg.Item2.Item6)); DebugKeepNode: arg.Item2.Item6);
}
public static Gen<ProxyForward> ProxyForward() { public static Gen<ProxyForward> ProxyForward() {
return Arb.Generate<Tuple<Tuple<string, long, Guid, Guid, Guid?, long>, Tuple<IPv4Address, DateTimeOffset>>>().Select( return Arb.Generate<Tuple<Tuple<string, long, Guid, Guid, Guid?, long>, Tuple<IPv4Address, DateTimeOffset>>>().Select(
@ -200,20 +205,20 @@ namespace Tests {
) )
); );
} }
public static Gen<Scaleset> Scaleset() { public static Gen<Scaleset> Scaleset { get; }
return Arb.Generate<Tuple< = from arg in Arb.Generate<Tuple<
Tuple<string, Guid, ScalesetState, Authentication?, string, string, string>, Tuple<Guid, ScalesetState, Authentication?, string, string, string>,
Tuple<int, bool, bool, bool, Error?, List<ScalesetNodeState>, Guid?>, Tuple<int, bool, bool, bool, Error?, List<ScalesetNodeState>, Guid?>,
Tuple<Guid?, Dictionary<string, string>>>>().Select( Tuple<Guid?, Dictionary<string, string>>>>()
arg => from poolName in PoolNameGen
new Scaleset( select new Scaleset(
PoolName: arg.Item1.Item1, PoolName: poolName,
ScalesetId: arg.Item1.Item2, ScalesetId: arg.Item1.Item1,
State: arg.Item1.Item3, State: arg.Item1.Item2,
Auth: arg.Item1.Item4, Auth: arg.Item1.Item3,
VmSku: arg.Item1.Item5, VmSku: arg.Item1.Item4,
Image: arg.Item1.Item6, Image: arg.Item1.Item5,
Region: arg.Item1.Item7, Region: arg.Item1.Item6,
Size: arg.Item2.Item1, Size: arg.Item2.Item1,
SpotInstance: arg.Item2.Item2, SpotInstance: arg.Item2.Item2,
@ -224,11 +229,7 @@ namespace Tests {
ClientId: arg.Item2.Item7, ClientId: arg.Item2.Item7,
ClientObjectId: arg.Item3.Item1, ClientObjectId: arg.Item3.Item1,
Tags: arg.Item3.Item2 Tags: arg.Item3.Item2);
)
);
}
public static Gen<Webhook> Webhook() { public static Gen<Webhook> Webhook() {
return Arb.Generate<Tuple<Guid, string, Uri?, List<EventType>, string, WebhookMessageFormat>>().Select( return Arb.Generate<Tuple<Guid, string, Uri?, List<EventType>, string, WebhookMessageFormat>>().Select(
@ -331,7 +332,9 @@ namespace Tests {
public class OrmArb { public class OrmArb {
public static Arbitrary<Version> Vresion() { public static Arbitrary<PoolName> PoolName { get; } = OrmGenerators.PoolNameGen.ToArbitrary();
public static Arbitrary<Version> Version() {
return Arb.From(OrmGenerators.Version()); return Arb.From(OrmGenerators.Version());
} }
@ -348,7 +351,7 @@ namespace Tests {
} }
public static Arbitrary<Node> Node() { public static Arbitrary<Node> Node() {
return Arb.From(OrmGenerators.Node()); return Arb.From(OrmGenerators.Node);
} }
public static Arbitrary<ProxyForward> ProxyForward() { public static Arbitrary<ProxyForward> ProxyForward() {
@ -383,9 +386,8 @@ namespace Tests {
return Arb.From(OrmGenerators.Task()); return Arb.From(OrmGenerators.Task());
} }
public static Arbitrary<Scaleset> Scaleset() { public static Arbitrary<Scaleset> Scaleset()
return Arb.From(OrmGenerators.Scaleset()); => Arb.From(OrmGenerators.Scaleset);
}
public static Arbitrary<Webhook> Webhook() { public static Arbitrary<Webhook> Webhook() {
return Arb.From(OrmGenerators.Webhook()); return Arb.From(OrmGenerators.Webhook());

View File

@ -234,7 +234,7 @@ namespace Tests {
[Fact] [Fact]
public void TestEventSerialization() { public void TestEventSerialization() {
var expectedEvent = new EventMessage(Guid.NewGuid(), EventType.NodeHeartbeat, new EventNodeHeartbeat(Guid.NewGuid(), Guid.NewGuid(), "test Poool"), Guid.NewGuid(), "test"); var expectedEvent = new EventMessage(Guid.NewGuid(), EventType.NodeHeartbeat, new EventNodeHeartbeat(Guid.NewGuid(), Guid.NewGuid(), PoolName.Parse("test-Poool")), Guid.NewGuid(), "test");
var serialized = JsonSerializer.Serialize(expectedEvent, EntityConverter.GetJsonSerializerOptions()); var serialized = JsonSerializer.Serialize(expectedEvent, EntityConverter.GetJsonSerializerOptions());
var actualEvent = JsonSerializer.Deserialize<EventMessage>(serialized, EntityConverter.GetJsonSerializerOptions()); var actualEvent = JsonSerializer.Deserialize<EventMessage>(serialized, EntityConverter.GetJsonSerializerOptions());
Assert.Equal(expectedEvent, actualEvent); Assert.Equal(expectedEvent, actualEvent);

View File

@ -0,0 +1,28 @@
using System.Text.Json;
using Microsoft.OneFuzz.Service;
using Xunit;
namespace Tests;
public class ValidatedStringTests {
record ThingContainingPoolName(PoolName PoolName);
[Fact]
public void PoolNameValidatesOnDeserialization() {
var ex = Assert.Throws<JsonException>(() => JsonSerializer.Deserialize<ThingContainingPoolName>("{ \"PoolName\": \"is-not!-a-pool\" }"));
Assert.Equal("unable to parse input as a PoolName", ex.Message);
}
[Fact]
public void PoolNameDeserializesFromString() {
var result = JsonSerializer.Deserialize<ThingContainingPoolName>("{ \"PoolName\": \"is-a-pool\" }");
Assert.Equal("is-a-pool", result?.PoolName.String);
}
[Fact]
public void PoolNameSerializesToString() {
var result = JsonSerializer.Serialize(new ThingContainingPoolName(PoolName.Parse("is-a-pool")));
Assert.Equal("{\"PoolName\":\"is-a-pool\"}", result);
}
}