diff --git a/src/ApiService/ApiService/Functions/Jobs.cs b/src/ApiService/ApiService/Functions/Jobs.cs index 1c739e65a..dd8e61e2b 100644 --- a/src/ApiService/ApiService/Functions/Jobs.cs +++ b/src/ApiService/ApiService/Functions/Jobs.cs @@ -23,7 +23,7 @@ public class Jobs { }); private async Task Post(HttpRequestData req) { - var request = await RequestHandling.ParseRequest(req); + var request = await RequestHandling.ParseRequest(req); if (!request.IsOk) { return await _context.RequestHandling.NotOk(req, request.ErrorV, "jobs create"); } @@ -33,10 +33,18 @@ public class Jobs { return await _context.RequestHandling.NotOk(req, userInfo.ErrorV, "jobs create"); } + var create = request.OkV; + var cfg = new JobConfig( + Build: create.Build, + Duration: create.Duration, + Logs: create.Logs, + Name: create.Name, + Project: create.Project); + var job = new Job( JobId: Guid.NewGuid(), State: JobState.Init, - Config: request.OkV) { + Config: cfg) { UserInfo = userInfo.OkV, }; diff --git a/src/ApiService/ApiService/Functions/ReproVmss.cs b/src/ApiService/ApiService/Functions/ReproVmss.cs index 8160b7b22..4b99877a7 100644 --- a/src/ApiService/ApiService/Functions/ReproVmss.cs +++ b/src/ApiService/ApiService/Functions/ReproVmss.cs @@ -51,7 +51,7 @@ public class ReproVmss { private async Async.Task Post(HttpRequestData req) { - var request = await RequestHandling.ParseRequest(req); + var request = await RequestHandling.ParseRequest(req); if (!request.IsOk) { return await _context.RequestHandling.NotOk( req, @@ -67,7 +67,13 @@ public class ReproVmss { "repro_vm create"); } - var vm = await _context.ReproOperations.Create(request.OkV, userInfo.OkV); + var create = request.OkV; + var cfg = new ReproConfig( + Container: create.Container, + Path: create.Path, + Duration: create.Duration); + + var vm = await _context.ReproOperations.Create(cfg, userInfo.OkV); if (!vm.IsOk) { return await _context.RequestHandling.NotOk( req, diff --git a/src/ApiService/ApiService/Functions/Tasks.cs b/src/ApiService/ApiService/Functions/Tasks.cs index ba254679e..be5c8b42d 100644 --- a/src/ApiService/ApiService/Functions/Tasks.cs +++ b/src/ApiService/ApiService/Functions/Tasks.cs @@ -72,7 +72,7 @@ public class Tasks { private async Async.Task Post(HttpRequestData req) { - var request = await RequestHandling.ParseRequest(req); + var request = await RequestHandling.ParseRequest(req); if (!request.IsOk) { return await _context.RequestHandling.NotOk( req, @@ -85,7 +85,19 @@ public class Tasks { return await _context.RequestHandling.NotOk(req, userInfo.ErrorV, "task create"); } - var checkConfig = await _context.Config.CheckConfig(request.OkV); + var create = request.OkV; + var cfg = new TaskConfig( + JobId: create.JobId, + PrereqTasks: create.PrereqTasks, + Task: create.Task, + Vm: null, + Pool: create.Pool, + Containers: create.Containers, + Tags: create.Tags, + Debug: create.Debug, + Colocate: create.Colocate); + + var checkConfig = await _context.Config.CheckConfig(cfg); if (!checkConfig.IsOk) { return await _context.RequestHandling.NotOk( req, @@ -99,23 +111,23 @@ public class Tasks { return response; } - var job = await _context.JobOperations.Get(request.OkV.JobId); + var job = await _context.JobOperations.Get(cfg.JobId); if (job == null) { return await _context.RequestHandling.NotOk( req, new Error(ErrorCode.INVALID_REQUEST, new[] { "unable to find job" }), - request.OkV.JobId.ToString()); + cfg.JobId.ToString()); } if (job.State != JobState.Enabled && job.State != JobState.Init) { return await _context.RequestHandling.NotOk( req, new Error(ErrorCode.UNABLE_TO_ADD_TASK_TO_JOB, new[] { $"unable to add a job in state {job.State}" }), - request.OkV.JobId.ToString()); + cfg.JobId.ToString()); } - if (request.OkV.PrereqTasks != null) { - foreach (var taskId in request.OkV.PrereqTasks) { + if (cfg.PrereqTasks != null) { + foreach (var taskId in cfg.PrereqTasks) { var prereq = await _context.TaskOperations.GetByTaskId(taskId); if (prereq == null) { @@ -127,7 +139,7 @@ public class Tasks { } } - var task = await _context.TaskOperations.Create(request.OkV, request.OkV.JobId, userInfo.OkV); + var task = await _context.TaskOperations.Create(cfg, cfg.JobId, userInfo.OkV); if (!task.IsOk) { return await _context.RequestHandling.NotOk( diff --git a/src/ApiService/ApiService/OneFuzzTypes/Model.cs b/src/ApiService/ApiService/OneFuzzTypes/Model.cs index 537800f95..850aa1ddd 100644 --- a/src/ApiService/ApiService/OneFuzzTypes/Model.cs +++ b/src/ApiService/ApiService/OneFuzzTypes/Model.cs @@ -229,20 +229,20 @@ public record TaskConfig( Dictionary? Tags = null, List? Debug = null, bool? Colocate = null - ); +); public record TaskEventSummary( DateTimeOffset? Timestamp, string EventData, string EventType - ); +); public record NodeAssignment( Guid NodeId, Guid? ScalesetId, NodeTaskState State - ); +); public record Task( diff --git a/src/ApiService/ApiService/OneFuzzTypes/Requests.cs b/src/ApiService/ApiService/OneFuzzTypes/Requests.cs index 4b939f210..405ca1463 100644 --- a/src/ApiService/ApiService/OneFuzzTypes/Requests.cs +++ b/src/ApiService/ApiService/OneFuzzTypes/Requests.cs @@ -1,30 +1,34 @@ using System.ComponentModel.DataAnnotations; +using System.Text.Json; using System.Text.Json.Serialization; namespace Microsoft.OneFuzz.Service; -public record BaseRequest(); +public record BaseRequest { + [JsonExtensionData] + public Dictionary? ExtensionData { get; set; } +}; public record CanScheduleRequest( - Guid MachineId, - Guid TaskId + [property: Required] Guid MachineId, + [property: Required] Guid TaskId ) : BaseRequest; public record NodeCommandGet( - Guid MachineId + [property: Required] Guid MachineId ) : BaseRequest; public record NodeCommandDelete( - Guid MachineId, - string MessageId + [property: Required] Guid MachineId, + [property: Required] string MessageId ) : BaseRequest; public record NodeGet( - Guid MachineId + [property: Required] Guid MachineId ) : BaseRequest; public record NodeUpdate( - Guid MachineId, + [property: Required] Guid MachineId, bool? DebugKeepNode ) : BaseRequest; @@ -36,8 +40,8 @@ public record NodeSearch( ) : BaseRequest; public record NodeStateEnvelope( - NodeEventBase Event, - Guid MachineId + [property: Required] NodeEventBase Event, + [property: Required] Guid MachineId ) : BaseRequest; // either NodeEvent or WorkerEvent @@ -59,16 +63,16 @@ public record WorkerEvent( ) : NodeEventBase; public record WorkerRunningEvent( - Guid TaskId); + [property: Required] Guid TaskId); public record WorkerDoneEvent( - Guid TaskId, - ExitStatus ExitStatus, - string Stderr, - string Stdout); + [property: Required] Guid TaskId, + [property: Required] ExitStatus ExitStatus, + [property: Required] string Stderr, + [property: Required] string Stdout); public record NodeStateUpdate( - NodeState State, + [property: Required] NodeState State, [property: JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] NodeStateData? Data = null ) : NodeEventBase; @@ -78,7 +82,7 @@ public record NodeStateUpdate( public abstract record NodeStateData; public record NodeSettingUpEventData( - List Tasks + [property: Required] List Tasks ) : NodeStateData; public record NodeDoneEventData( @@ -101,23 +105,23 @@ public record ExitStatus( bool Success); public record ContainerGet( - Container Name + [property: Required] Container Name ) : BaseRequest; public record ContainerCreate( - Container Name, + [property: Required] Container Name, IDictionary? Metadata = null ) : BaseRequest; public record ContainerDelete( - Container Name, + [property: Required] Container Name, IDictionary? Metadata = null ) : BaseRequest; public record NotificationCreate( - Container Container, - bool ReplaceExisting, - NotificationTemplate Config + [property: Required] Container Container, + [property: Required] bool ReplaceExisting, + [property: Required] NotificationTemplate Config ) : BaseRequest; public record NotificationSearch( @@ -125,58 +129,75 @@ public record NotificationSearch( ) : BaseRequest; public record NotificationGet( - Guid NotificationId + [property: Required] Guid NotificationId ) : BaseRequest; public record JobGet( - Guid JobId -); + [property: Required] Guid JobId +) : BaseRequest; + +public record JobCreate( + [property: Required] string Project, + [property: Required] string Name, + [property: Required] string Build, + [property: Required] long Duration, + string? Logs +) : BaseRequest; public record JobSearch( Guid? JobId = null, List? State = null, List? TaskState = null, bool? WithTasks = null -); +) : BaseRequest; -public record NodeAddSshKeyPost(Guid MachineId, string PublicKey); +public record NodeAddSshKeyPost( + [property: Required] Guid MachineId, + [property: Required] string PublicKey +) : BaseRequest; -public record ReproGet(Guid? VmId); +public record ReproGet(Guid? VmId) : BaseRequest; + +public record ReproCreate( + [property: Required] Container Container, + [property: Required] string Path, + [property: Required] long Duration +) : BaseRequest; public record ProxyGet( Guid? ScalesetId, Guid? MachineId, - int? DstPort); + int? DstPort +) : BaseRequest; public record ProxyCreate( - Guid ScalesetId, - Guid MachineId, - int DstPort, - int Duration -); + [property: Required] Guid ScalesetId, + [property: Required] Guid MachineId, + [property: Required] int DstPort, + [property: Required] int Duration +) : BaseRequest; public record ProxyDelete( - Guid ScalesetId, - Guid MachineId, + [property: Required] Guid ScalesetId, + [property: Required] Guid MachineId, int? DstPort -); +) : BaseRequest; public record ProxyReset( - string Region -); + [property: Required] string Region +) : BaseRequest; public record ScalesetCreate( - PoolName PoolName, - string VmSku, - string Image, + [property: Required] PoolName PoolName, + [property: Required] string VmSku, + [property: Required] string Image, string? Region, - [property: Range(1, long.MaxValue)] - long Size, - bool SpotInstances, - Dictionary Tags, + [property: Range(1, long.MaxValue), Required] long Size, + [property: Required] bool SpotInstances, + [property: Required] Dictionary Tags, bool EphemeralOsDisks = false, AutoScaleOptions? AutoScale = null -); +) : BaseRequest; public record AutoScaleOptions( [property: Range(0, long.MaxValue)] long Min, @@ -192,63 +213,75 @@ public record ScalesetSearch( Guid? ScalesetId = null, List? State = null, bool IncludeAuth = false -); +) : BaseRequest; public record ScalesetStop( - Guid ScalesetId, - bool Now -); + [property: Required] Guid ScalesetId, + [property: Required] bool Now +) : BaseRequest; public record ScalesetUpdate( - Guid ScalesetId, + [property: Required] Guid ScalesetId, [property: Range(1, long.MaxValue)] long? Size -); +) : BaseRequest; -public record TaskGet(Guid TaskId); +public record TaskGet( + [property: Required] Guid TaskId +) : BaseRequest; + +public record TaskCreate( + [property: Required] Guid JobId, + List? PrereqTasks, + [property: Required] TaskDetails Task, + [property: Required] TaskPool Pool, + List? Containers = null, + Dictionary? Tags = null, + List? Debug = null, + bool? Colocate = null +) : BaseRequest; public record TaskSearch( Guid? JobId, Guid? TaskId, - List State); + [property: Required] List State) : BaseRequest; public record PoolSearch( Guid? PoolId = null, PoolName? Name = null, List? State = null -); +) : BaseRequest; public record PoolStop( - PoolName Name, - bool Now -); + [property: Required] PoolName Name, + [property: Required] bool Now +) : BaseRequest; public record PoolCreate( - PoolName Name, - Os Os, - Architecture Arch, - bool Managed, + [property: Required] PoolName Name, + [property: Required] Os Os, + [property: Required] Architecture Arch, + [property: Required] bool Managed, Guid? ClientId = null -); +) : BaseRequest; public record WebhookCreate( - string Name, - Uri Url, - List EventTypes, + [property: Required] string Name, + [property: Required] Uri Url, + [property: Required] List EventTypes, string? SecretToken, WebhookMessageFormat? MessageFormat -); +) : BaseRequest; +public record WebhookSearch(Guid? WebhookId) : BaseRequest; -public record WebhookSearch(Guid? WebhookId); - -public record WebhookGet(Guid WebhookId); +public record WebhookGet([property: Required] Guid WebhookId) : BaseRequest; public record WebhookUpdate( - Guid WebhookId, + [property: Required] Guid WebhookId, string? Name, Uri? Url, List? EventTypes, string? SecretToken, WebhookMessageFormat? MessageFormat -); +) : BaseRequest; diff --git a/src/ApiService/ApiService/onefuzzlib/Request.cs b/src/ApiService/ApiService/onefuzzlib/Request.cs index 9581fba8c..f60f097ad 100644 --- a/src/ApiService/ApiService/onefuzzlib/Request.cs +++ b/src/ApiService/ApiService/onefuzzlib/Request.cs @@ -31,14 +31,35 @@ public class RequestHandling : IRequestHandling { throw new ArgumentOutOfRangeException($"status code {statusCode} - {statusNum} is not in the expected range [400; 599]"); } - public static async Async.Task> ParseRequest(HttpRequestData req) { + public static async Async.Task> ParseRequest(HttpRequestData req) + where T : BaseRequest { Exception? exception = null; try { var t = await req.ReadFromJsonAsync(); if (t != null) { + + // ExtensionData is used here to detect if there are any unknown + // properties set: + if (t.ExtensionData != null) { + var errors = new List(); + foreach (var (name, value) in t.ExtensionData) { + // allow additional properties if they are null, + // otherwise produce an error + if (value.ValueKind != JsonValueKind.Null) { + errors.Add($"Unexpected property: \"{name}\""); + } + } + + if (errors.Any()) { + return new Error( + Code: ErrorCode.INVALID_REQUEST, + Errors: errors.ToArray()); + } + } + var validationContext = new ValidationContext(t); var validationResults = new List(); - if (Validator.TryValidateObject(t, validationContext, validationResults, true)) { + if (Validator.TryValidateObject(t, validationContext, validationResults, validateAllProperties: true)) { return OneFuzzResult.Ok(t); } else { return new Error( @@ -48,8 +69,7 @@ public class RequestHandling : IRequestHandling { } else { return OneFuzzResult.Error( ErrorCode.INVALID_REQUEST, - $"Failed to deserialize message into type: {typeof(T)} - null" - ); + $"Failed to deserialize message into type: {typeof(T)} - null"); } } catch (Exception e) { exception = e; diff --git a/src/ApiService/ApiService/onefuzzlib/orm/EntityConverter.cs b/src/ApiService/ApiService/onefuzzlib/orm/EntityConverter.cs index 03e17dfce..9a0aee462 100644 --- a/src/ApiService/ApiService/onefuzzlib/orm/EntityConverter.cs +++ b/src/ApiService/ApiService/onefuzzlib/orm/EntityConverter.cs @@ -89,15 +89,13 @@ class OnefuzzNamingPolicy : JsonNamingPolicy { public class EntityConverter { private readonly ConcurrentDictionary _cache; - private static readonly JsonSerializerOptions _options; - - static EntityConverter() { - _options = new JsonSerializerOptions() { - PropertyNamingPolicy = new OnefuzzNamingPolicy(), - }; - _options.Converters.Add(new CustomEnumConverterFactory()); - _options.Converters.Add(new PolymorphicConverterFactory()); - } + private static readonly JsonSerializerOptions _options = new() { + PropertyNamingPolicy = new OnefuzzNamingPolicy(), + Converters = { + new CustomEnumConverterFactory(), + new PolymorphicConverterFactory(), + } + }; public EntityConverter() { _cache = new ConcurrentDictionary(); diff --git a/src/ApiService/IntegrationTests/TasksTests.cs b/src/ApiService/IntegrationTests/TasksTests.cs new file mode 100644 index 000000000..186ffd1ab --- /dev/null +++ b/src/ApiService/IntegrationTests/TasksTests.cs @@ -0,0 +1,71 @@ +using System; +using System.Net; +using System.Text.Json; +using System.Text.Json.Nodes; +using IntegrationTests.Fakes; +using Microsoft.OneFuzz.Service; +using Microsoft.OneFuzz.Service.Functions; +using Microsoft.OneFuzz.Service.OneFuzzLib.Orm; +using Xunit; +using Xunit.Abstractions; +using Async = System.Threading.Tasks; + +namespace IntegrationTests.Functions; + +[Trait("Category", "Live")] +public class AzureStorageTasksTest : TasksTestBase { + public AzureStorageTasksTest(ITestOutputHelper output) + : base(output, Integration.AzureStorage.FromEnvironment()) { } +} + +public class AzuriteTasksTest : TasksTestBase { + public AzuriteTasksTest(ITestOutputHelper output) + : base(output, new Integration.AzuriteStorage()) { } +} + +public abstract class TasksTestBase : FunctionTestBase { + public TasksTestBase(ITestOutputHelper output, IStorage storage) + : base(output, storage) { } + + [Fact] + public async Async.Task SpecifyingVmIsNotPermitted() { + var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); + var func = new Tasks(Logger, auth, Context); + + var req = new TaskCreate( + Guid.NewGuid(), + null, + new TaskDetails(TaskType.DotnetCoverage, 100), + new TaskPool(1, PoolName.Parse("pool"))); + + // the 'vm' property used to be permitted but is no longer, add it: + var serialized = (JsonObject?)JsonSerializer.SerializeToNode(req, EntityConverter.GetJsonSerializerOptions()); + serialized!["vm"] = new JsonObject { { "fake", 1 } }; + var testData = new TestHttpRequestData("POST", new BinaryData(JsonSerializer.SerializeToUtf8Bytes(serialized, EntityConverter.GetJsonSerializerOptions()))); + var result = await func.Run(testData); + Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); + var err = BodyAs(result); + Assert.Equal(new[] { "Unexpected property: \"vm\"" }, err.Errors); + } + + [Fact] + public async Async.Task PoolIsRequired() { + var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); + var func = new Tasks(Logger, auth, Context); + + // override the found user credentials - need these to check for admin + var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: Guid.NewGuid(), "upn"); + Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); + + var req = new TaskCreate( + Guid.NewGuid(), + null, + new TaskDetails(TaskType.DotnetCoverage, 100), + null! /* <- here */); + + var result = await func.Run(TestHttpRequestData.FromJson("POST", req)); + Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); + var err = BodyAs(result); + Assert.Equal(new[] { "The Pool field is required." }, err.Errors); + } +} diff --git a/src/ApiService/Tests/RequestsTests.cs b/src/ApiService/Tests/RequestsTests.cs index b79c44e98..dc7637e5c 100644 --- a/src/ApiService/Tests/RequestsTests.cs +++ b/src/ApiService/Tests/RequestsTests.cs @@ -1,8 +1,14 @@ -using System.IO; +using System; +using System.Collections.Generic; +using System.ComponentModel.DataAnnotations; +using System.IO; +using System.Linq; +using System.Reflection; using System.Text; using System.Text.Json; using System.Threading; using Azure.Core.Serialization; +using FluentAssertions; using Microsoft.OneFuzz.Service; using Microsoft.OneFuzz.Service.OneFuzzLib.Orm; using Xunit; @@ -40,6 +46,43 @@ public class RequestsTests { Assert.Equal(json, result); } + // Finds all non-nullable properties exposed on request objects (inheriting from BaseRequest). + // Note that at the moment we do not validate inner types since we are reusing some model types + // as request objects/DTOs, which we should stop doing. + public static IEnumerable NonNullableRequestProperties() { + var baseType = typeof(BaseRequest); + var asm = baseType.Assembly; + foreach (var requestType in asm.GetTypes().Where(t => t.IsAssignableTo(baseType))) { + if (requestType == baseType) { + continue; + } + + foreach (var property in requestType.GetProperties()) { + var nullabilityContext = new NullabilityInfoContext(); + var nullability = nullabilityContext.Create(property); + if (nullability.ReadState == NullabilityState.NotNull) { + yield return new object[] { requestType, property }; + } + } + } + } + + [Theory] + [MemberData(nameof(NonNullableRequestProperties))] + public void EnsureRequiredAttributesExistsOnNonNullableRequestProperties(Type requestType, PropertyInfo property) { + if (!property.IsDefined(typeof(RequiredAttribute))) { + // if not required it must have a default + + // find appropriate parameter + var param = requestType.GetConstructors().Single().GetParameters().Single(p => p.Name == property.Name); + Assert.True(param.HasDefaultValue, + "For request types, all non-nullable properties should either have a default value, or the [Required] attribute." + ); + } else { + // it is required, okay + } + } + [Fact] public void NodeEvent_WorkerEvent_Done() { // generated with: onefuzz-agent debug node_event worker_event done