Enforce that there no extra properties in request JSON, and that non-null properties are [Required] (#2328)

Closes #2314 via two fixes, one for additional properties and one for missing properties:

- Make all request types inherit from `BaseRequest` which has an `ExtensionData` property, and ensure that it is empty in `ParseRequest`.
- Add `[Required]` attribute to non-nullable properties that do not have defaults, and add a test that ensures we have this attribute where necessary.
This commit is contained in:
George Pollard
2022-09-02 13:59:24 +12:00
committed by GitHub
parent 52ba57bf0d
commit c54db04083
9 changed files with 294 additions and 103 deletions

View File

@ -23,7 +23,7 @@ public class Jobs {
});
private async Task<HttpResponseData> Post(HttpRequestData req) {
var request = await RequestHandling.ParseRequest<JobConfig>(req);
var request = await RequestHandling.ParseRequest<JobCreate>(req);
if (!request.IsOk) {
return await _context.RequestHandling.NotOk(req, request.ErrorV, "jobs create");
}
@ -33,10 +33,18 @@ public class Jobs {
return await _context.RequestHandling.NotOk(req, userInfo.ErrorV, "jobs create");
}
var create = request.OkV;
var cfg = new JobConfig(
Build: create.Build,
Duration: create.Duration,
Logs: create.Logs,
Name: create.Name,
Project: create.Project);
var job = new Job(
JobId: Guid.NewGuid(),
State: JobState.Init,
Config: request.OkV) {
Config: cfg) {
UserInfo = userInfo.OkV,
};

View File

@ -51,7 +51,7 @@ public class ReproVmss {
private async Async.Task<HttpResponseData> Post(HttpRequestData req) {
var request = await RequestHandling.ParseRequest<ReproConfig>(req);
var request = await RequestHandling.ParseRequest<ReproCreate>(req);
if (!request.IsOk) {
return await _context.RequestHandling.NotOk(
req,
@ -67,7 +67,13 @@ public class ReproVmss {
"repro_vm create");
}
var vm = await _context.ReproOperations.Create(request.OkV, userInfo.OkV);
var create = request.OkV;
var cfg = new ReproConfig(
Container: create.Container,
Path: create.Path,
Duration: create.Duration);
var vm = await _context.ReproOperations.Create(cfg, userInfo.OkV);
if (!vm.IsOk) {
return await _context.RequestHandling.NotOk(
req,

View File

@ -72,7 +72,7 @@ public class Tasks {
private async Async.Task<HttpResponseData> Post(HttpRequestData req) {
var request = await RequestHandling.ParseRequest<TaskConfig>(req);
var request = await RequestHandling.ParseRequest<TaskCreate>(req);
if (!request.IsOk) {
return await _context.RequestHandling.NotOk(
req,
@ -85,7 +85,19 @@ public class Tasks {
return await _context.RequestHandling.NotOk(req, userInfo.ErrorV, "task create");
}
var checkConfig = await _context.Config.CheckConfig(request.OkV);
var create = request.OkV;
var cfg = new TaskConfig(
JobId: create.JobId,
PrereqTasks: create.PrereqTasks,
Task: create.Task,
Vm: null,
Pool: create.Pool,
Containers: create.Containers,
Tags: create.Tags,
Debug: create.Debug,
Colocate: create.Colocate);
var checkConfig = await _context.Config.CheckConfig(cfg);
if (!checkConfig.IsOk) {
return await _context.RequestHandling.NotOk(
req,
@ -99,23 +111,23 @@ public class Tasks {
return response;
}
var job = await _context.JobOperations.Get(request.OkV.JobId);
var job = await _context.JobOperations.Get(cfg.JobId);
if (job == null) {
return await _context.RequestHandling.NotOk(
req,
new Error(ErrorCode.INVALID_REQUEST, new[] { "unable to find job" }),
request.OkV.JobId.ToString());
cfg.JobId.ToString());
}
if (job.State != JobState.Enabled && job.State != JobState.Init) {
return await _context.RequestHandling.NotOk(
req,
new Error(ErrorCode.UNABLE_TO_ADD_TASK_TO_JOB, new[] { $"unable to add a job in state {job.State}" }),
request.OkV.JobId.ToString());
cfg.JobId.ToString());
}
if (request.OkV.PrereqTasks != null) {
foreach (var taskId in request.OkV.PrereqTasks) {
if (cfg.PrereqTasks != null) {
foreach (var taskId in cfg.PrereqTasks) {
var prereq = await _context.TaskOperations.GetByTaskId(taskId);
if (prereq == null) {
@ -127,7 +139,7 @@ public class Tasks {
}
}
var task = await _context.TaskOperations.Create(request.OkV, request.OkV.JobId, userInfo.OkV);
var task = await _context.TaskOperations.Create(cfg, cfg.JobId, userInfo.OkV);
if (!task.IsOk) {
return await _context.RequestHandling.NotOk(

View File

@ -229,20 +229,20 @@ public record TaskConfig(
Dictionary<string, string>? Tags = null,
List<TaskDebugFlag>? Debug = null,
bool? Colocate = null
);
);
public record TaskEventSummary(
DateTimeOffset? Timestamp,
string EventData,
string EventType
);
);
public record NodeAssignment(
Guid NodeId,
Guid? ScalesetId,
NodeTaskState State
);
);
public record Task(

View File

@ -1,30 +1,34 @@
using System.ComponentModel.DataAnnotations;
using System.Text.Json;
using System.Text.Json.Serialization;
namespace Microsoft.OneFuzz.Service;
public record BaseRequest();
public record BaseRequest {
[JsonExtensionData]
public Dictionary<string, JsonElement>? ExtensionData { get; set; }
};
public record CanScheduleRequest(
Guid MachineId,
Guid TaskId
[property: Required] Guid MachineId,
[property: Required] Guid TaskId
) : BaseRequest;
public record NodeCommandGet(
Guid MachineId
[property: Required] Guid MachineId
) : BaseRequest;
public record NodeCommandDelete(
Guid MachineId,
string MessageId
[property: Required] Guid MachineId,
[property: Required] string MessageId
) : BaseRequest;
public record NodeGet(
Guid MachineId
[property: Required] Guid MachineId
) : BaseRequest;
public record NodeUpdate(
Guid MachineId,
[property: Required] Guid MachineId,
bool? DebugKeepNode
) : BaseRequest;
@ -36,8 +40,8 @@ public record NodeSearch(
) : BaseRequest;
public record NodeStateEnvelope(
NodeEventBase Event,
Guid MachineId
[property: Required] NodeEventBase Event,
[property: Required] Guid MachineId
) : BaseRequest;
// either NodeEvent or WorkerEvent
@ -59,16 +63,16 @@ public record WorkerEvent(
) : NodeEventBase;
public record WorkerRunningEvent(
Guid TaskId);
[property: Required] Guid TaskId);
public record WorkerDoneEvent(
Guid TaskId,
ExitStatus ExitStatus,
string Stderr,
string Stdout);
[property: Required] Guid TaskId,
[property: Required] ExitStatus ExitStatus,
[property: Required] string Stderr,
[property: Required] string Stdout);
public record NodeStateUpdate(
NodeState State,
[property: Required] NodeState State,
[property: JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
NodeStateData? Data = null
) : NodeEventBase;
@ -78,7 +82,7 @@ public record NodeStateUpdate(
public abstract record NodeStateData;
public record NodeSettingUpEventData(
List<Guid> Tasks
[property: Required] List<Guid> Tasks
) : NodeStateData;
public record NodeDoneEventData(
@ -101,23 +105,23 @@ public record ExitStatus(
bool Success);
public record ContainerGet(
Container Name
[property: Required] Container Name
) : BaseRequest;
public record ContainerCreate(
Container Name,
[property: Required] Container Name,
IDictionary<string, string>? Metadata = null
) : BaseRequest;
public record ContainerDelete(
Container Name,
[property: Required] Container Name,
IDictionary<string, string>? Metadata = null
) : BaseRequest;
public record NotificationCreate(
Container Container,
bool ReplaceExisting,
NotificationTemplate Config
[property: Required] Container Container,
[property: Required] bool ReplaceExisting,
[property: Required] NotificationTemplate Config
) : BaseRequest;
public record NotificationSearch(
@ -125,58 +129,75 @@ public record NotificationSearch(
) : BaseRequest;
public record NotificationGet(
Guid NotificationId
[property: Required] Guid NotificationId
) : BaseRequest;
public record JobGet(
Guid JobId
);
[property: Required] Guid JobId
) : BaseRequest;
public record JobCreate(
[property: Required] string Project,
[property: Required] string Name,
[property: Required] string Build,
[property: Required] long Duration,
string? Logs
) : BaseRequest;
public record JobSearch(
Guid? JobId = null,
List<JobState>? State = null,
List<TaskState>? TaskState = null,
bool? WithTasks = null
);
) : BaseRequest;
public record NodeAddSshKeyPost(Guid MachineId, string PublicKey);
public record NodeAddSshKeyPost(
[property: Required] Guid MachineId,
[property: Required] string PublicKey
) : BaseRequest;
public record ReproGet(Guid? VmId);
public record ReproGet(Guid? VmId) : BaseRequest;
public record ReproCreate(
[property: Required] Container Container,
[property: Required] string Path,
[property: Required] long Duration
) : BaseRequest;
public record ProxyGet(
Guid? ScalesetId,
Guid? MachineId,
int? DstPort);
int? DstPort
) : BaseRequest;
public record ProxyCreate(
Guid ScalesetId,
Guid MachineId,
int DstPort,
int Duration
);
[property: Required] Guid ScalesetId,
[property: Required] Guid MachineId,
[property: Required] int DstPort,
[property: Required] int Duration
) : BaseRequest;
public record ProxyDelete(
Guid ScalesetId,
Guid MachineId,
[property: Required] Guid ScalesetId,
[property: Required] Guid MachineId,
int? DstPort
);
) : BaseRequest;
public record ProxyReset(
string Region
);
[property: Required] string Region
) : BaseRequest;
public record ScalesetCreate(
PoolName PoolName,
string VmSku,
string Image,
[property: Required] PoolName PoolName,
[property: Required] string VmSku,
[property: Required] string Image,
string? Region,
[property: Range(1, long.MaxValue)]
long Size,
bool SpotInstances,
Dictionary<string, string> Tags,
[property: Range(1, long.MaxValue), Required] long Size,
[property: Required] bool SpotInstances,
[property: Required] Dictionary<string, string> Tags,
bool EphemeralOsDisks = false,
AutoScaleOptions? AutoScale = null
);
) : BaseRequest;
public record AutoScaleOptions(
[property: Range(0, long.MaxValue)] long Min,
@ -192,63 +213,75 @@ public record ScalesetSearch(
Guid? ScalesetId = null,
List<ScalesetState>? State = null,
bool IncludeAuth = false
);
) : BaseRequest;
public record ScalesetStop(
Guid ScalesetId,
bool Now
);
[property: Required] Guid ScalesetId,
[property: Required] bool Now
) : BaseRequest;
public record ScalesetUpdate(
Guid ScalesetId,
[property: Required] Guid ScalesetId,
[property: Range(1, long.MaxValue)]
long? Size
);
) : BaseRequest;
public record TaskGet(Guid TaskId);
public record TaskGet(
[property: Required] Guid TaskId
) : BaseRequest;
public record TaskCreate(
[property: Required] Guid JobId,
List<Guid>? PrereqTasks,
[property: Required] TaskDetails Task,
[property: Required] TaskPool Pool,
List<TaskContainers>? Containers = null,
Dictionary<string, string>? Tags = null,
List<TaskDebugFlag>? Debug = null,
bool? Colocate = null
) : BaseRequest;
public record TaskSearch(
Guid? JobId,
Guid? TaskId,
List<TaskState> State);
[property: Required] List<TaskState> State) : BaseRequest;
public record PoolSearch(
Guid? PoolId = null,
PoolName? Name = null,
List<PoolState>? State = null
);
) : BaseRequest;
public record PoolStop(
PoolName Name,
bool Now
);
[property: Required] PoolName Name,
[property: Required] bool Now
) : BaseRequest;
public record PoolCreate(
PoolName Name,
Os Os,
Architecture Arch,
bool Managed,
[property: Required] PoolName Name,
[property: Required] Os Os,
[property: Required] Architecture Arch,
[property: Required] bool Managed,
Guid? ClientId = null
);
) : BaseRequest;
public record WebhookCreate(
string Name,
Uri Url,
List<EventType> EventTypes,
[property: Required] string Name,
[property: Required] Uri Url,
[property: Required] List<EventType> EventTypes,
string? SecretToken,
WebhookMessageFormat? MessageFormat
);
) : BaseRequest;
public record WebhookSearch(Guid? WebhookId) : BaseRequest;
public record WebhookSearch(Guid? WebhookId);
public record WebhookGet(Guid WebhookId);
public record WebhookGet([property: Required] Guid WebhookId) : BaseRequest;
public record WebhookUpdate(
Guid WebhookId,
[property: Required] Guid WebhookId,
string? Name,
Uri? Url,
List<EventType>? EventTypes,
string? SecretToken,
WebhookMessageFormat? MessageFormat
);
) : BaseRequest;

View File

@ -31,14 +31,35 @@ public class RequestHandling : IRequestHandling {
throw new ArgumentOutOfRangeException($"status code {statusCode} - {statusNum} is not in the expected range [400; 599]");
}
public static async Async.Task<OneFuzzResult<T>> ParseRequest<T>(HttpRequestData req) {
public static async Async.Task<OneFuzzResult<T>> ParseRequest<T>(HttpRequestData req)
where T : BaseRequest {
Exception? exception = null;
try {
var t = await req.ReadFromJsonAsync<T>();
if (t != null) {
// ExtensionData is used here to detect if there are any unknown
// properties set:
if (t.ExtensionData != null) {
var errors = new List<string>();
foreach (var (name, value) in t.ExtensionData) {
// allow additional properties if they are null,
// otherwise produce an error
if (value.ValueKind != JsonValueKind.Null) {
errors.Add($"Unexpected property: \"{name}\"");
}
}
if (errors.Any()) {
return new Error(
Code: ErrorCode.INVALID_REQUEST,
Errors: errors.ToArray());
}
}
var validationContext = new ValidationContext(t);
var validationResults = new List<ValidationResult>();
if (Validator.TryValidateObject(t, validationContext, validationResults, true)) {
if (Validator.TryValidateObject(t, validationContext, validationResults, validateAllProperties: true)) {
return OneFuzzResult.Ok(t);
} else {
return new Error(
@ -48,8 +69,7 @@ public class RequestHandling : IRequestHandling {
} else {
return OneFuzzResult<T>.Error(
ErrorCode.INVALID_REQUEST,
$"Failed to deserialize message into type: {typeof(T)} - null"
);
$"Failed to deserialize message into type: {typeof(T)} - null");
}
} catch (Exception e) {
exception = e;

View File

@ -89,15 +89,13 @@ class OnefuzzNamingPolicy : JsonNamingPolicy {
public class EntityConverter {
private readonly ConcurrentDictionary<Type, EntityInfo> _cache;
private static readonly JsonSerializerOptions _options;
static EntityConverter() {
_options = new JsonSerializerOptions() {
private static readonly JsonSerializerOptions _options = new() {
PropertyNamingPolicy = new OnefuzzNamingPolicy(),
};
_options.Converters.Add(new CustomEnumConverterFactory());
_options.Converters.Add(new PolymorphicConverterFactory());
Converters = {
new CustomEnumConverterFactory(),
new PolymorphicConverterFactory(),
}
};
public EntityConverter() {
_cache = new ConcurrentDictionary<Type, EntityInfo>();

View File

@ -0,0 +1,71 @@
using System;
using System.Net;
using System.Text.Json;
using System.Text.Json.Nodes;
using IntegrationTests.Fakes;
using Microsoft.OneFuzz.Service;
using Microsoft.OneFuzz.Service.Functions;
using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
using Xunit;
using Xunit.Abstractions;
using Async = System.Threading.Tasks;
namespace IntegrationTests.Functions;
[Trait("Category", "Live")]
public class AzureStorageTasksTest : TasksTestBase {
public AzureStorageTasksTest(ITestOutputHelper output)
: base(output, Integration.AzureStorage.FromEnvironment()) { }
}
public class AzuriteTasksTest : TasksTestBase {
public AzuriteTasksTest(ITestOutputHelper output)
: base(output, new Integration.AzuriteStorage()) { }
}
public abstract class TasksTestBase : FunctionTestBase {
public TasksTestBase(ITestOutputHelper output, IStorage storage)
: base(output, storage) { }
[Fact]
public async Async.Task SpecifyingVmIsNotPermitted() {
var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context);
var func = new Tasks(Logger, auth, Context);
var req = new TaskCreate(
Guid.NewGuid(),
null,
new TaskDetails(TaskType.DotnetCoverage, 100),
new TaskPool(1, PoolName.Parse("pool")));
// the 'vm' property used to be permitted but is no longer, add it:
var serialized = (JsonObject?)JsonSerializer.SerializeToNode(req, EntityConverter.GetJsonSerializerOptions());
serialized!["vm"] = new JsonObject { { "fake", 1 } };
var testData = new TestHttpRequestData("POST", new BinaryData(JsonSerializer.SerializeToUtf8Bytes(serialized, EntityConverter.GetJsonSerializerOptions())));
var result = await func.Run(testData);
Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode);
var err = BodyAs<Error>(result);
Assert.Equal(new[] { "Unexpected property: \"vm\"" }, err.Errors);
}
[Fact]
public async Async.Task PoolIsRequired() {
var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context);
var func = new Tasks(Logger, auth, Context);
// override the found user credentials - need these to check for admin
var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: Guid.NewGuid(), "upn");
Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult<UserInfo>.Ok(userInfo));
var req = new TaskCreate(
Guid.NewGuid(),
null,
new TaskDetails(TaskType.DotnetCoverage, 100),
null! /* <- here */);
var result = await func.Run(TestHttpRequestData.FromJson("POST", req));
Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode);
var err = BodyAs<Error>(result);
Assert.Equal(new[] { "The Pool field is required." }, err.Errors);
}
}

View File

@ -1,8 +1,14 @@
using System.IO;
using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Text;
using System.Text.Json;
using System.Threading;
using Azure.Core.Serialization;
using FluentAssertions;
using Microsoft.OneFuzz.Service;
using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
using Xunit;
@ -40,6 +46,43 @@ public class RequestsTests {
Assert.Equal(json, result);
}
// Finds all non-nullable properties exposed on request objects (inheriting from BaseRequest).
// Note that at the moment we do not validate inner types since we are reusing some model types
// as request objects/DTOs, which we should stop doing.
public static IEnumerable<object[]> NonNullableRequestProperties() {
var baseType = typeof(BaseRequest);
var asm = baseType.Assembly;
foreach (var requestType in asm.GetTypes().Where(t => t.IsAssignableTo(baseType))) {
if (requestType == baseType) {
continue;
}
foreach (var property in requestType.GetProperties()) {
var nullabilityContext = new NullabilityInfoContext();
var nullability = nullabilityContext.Create(property);
if (nullability.ReadState == NullabilityState.NotNull) {
yield return new object[] { requestType, property };
}
}
}
}
[Theory]
[MemberData(nameof(NonNullableRequestProperties))]
public void EnsureRequiredAttributesExistsOnNonNullableRequestProperties(Type requestType, PropertyInfo property) {
if (!property.IsDefined(typeof(RequiredAttribute))) {
// if not required it must have a default
// find appropriate parameter
var param = requestType.GetConstructors().Single().GetParameters().Single(p => p.Name == property.Name);
Assert.True(param.HasDefaultValue,
"For request types, all non-nullable properties should either have a default value, or the [Required] attribute."
);
} else {
// it is required, okay
}
}
[Fact]
public void NodeEvent_WorkerEvent_Done() {
// generated with: onefuzz-agent debug node_event worker_event done