diff --git a/src/ApiService/ApiService/AgentCanSchedule.cs b/src/ApiService/ApiService/AgentCanSchedule.cs index 577ba8d28..a93b48dc4 100644 --- a/src/ApiService/ApiService/AgentCanSchedule.cs +++ b/src/ApiService/ApiService/AgentCanSchedule.cs @@ -5,18 +5,24 @@ namespace Microsoft.OneFuzz.Service; public class AgentCanSchedule { private readonly ILogTracer _log; - + private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public AgentCanSchedule(ILogTracer log, IOnefuzzContext context) { + public AgentCanSchedule(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { _log = log; + _auth = auth; _context = context; } // [Function("AgentCanSchedule")] - public async Async.Task Run([HttpTrigger] HttpRequestData req) { + public Async.Task Run( + [HttpTrigger(AuthorizationLevel.Anonymous, "POST", Route="agents/can_schedule")] + HttpRequestData req) + => _auth.CallIfAgent(req, Post); + + private async Async.Task Post(HttpRequestData req) { var request = await RequestHandling.ParseRequest(req); - if (!request.IsOk || request.OkV == null) { + if (!request.IsOk) { return await _context.RequestHandling.NotOk(req, request.ErrorV, typeof(CanScheduleRequest).ToString()); } diff --git a/src/ApiService/ApiService/AgentCommands.cs b/src/ApiService/ApiService/AgentCommands.cs index b0b336f61..54979907d 100644 --- a/src/ApiService/ApiService/AgentCommands.cs +++ b/src/ApiService/ApiService/AgentCommands.cs @@ -5,26 +5,28 @@ namespace Microsoft.OneFuzz.Service; public class AgentCommands { private readonly ILogTracer _log; - + private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public AgentCommands(ILogTracer log, IOnefuzzContext context) { + public AgentCommands(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { _log = log; + _auth = auth; _context = context; } // [Function("AgentCommands")] - public async Async.Task Run([HttpTrigger("get", "delete")] HttpRequestData req) { - return req.Method switch { - "GET" => await Get(req), - "DELETE" => await Delete(req), + public Async.Task Run( + [HttpTrigger(AuthorizationLevel.Anonymous, "GET", "DELETE", Route="agents/commands")] + HttpRequestData req) + => _auth.CallIfAgent(req, r => r.Method switch { + "GET" => Get(req), + "DELETE" => Delete(req), _ => throw new NotImplementedException($"HTTP Method {req.Method} is not supported for this method") - }; - } + }); private async Async.Task Get(HttpRequestData req) { var request = await RequestHandling.ParseRequest(req); - if (!request.IsOk || request.OkV == null) { + if (!request.IsOk) { return await _context.RequestHandling.NotOk(req, request.ErrorV, typeof(NodeCommandGet).ToString()); } var nodeCommand = request.OkV; diff --git a/src/ApiService/ApiService/AgentEvents.cs b/src/ApiService/ApiService/AgentEvents.cs index 3c9e1d6f6..384fb749d 100644 --- a/src/ApiService/ApiService/AgentEvents.cs +++ b/src/ApiService/ApiService/AgentEvents.cs @@ -7,20 +7,26 @@ namespace Microsoft.OneFuzz.Service; public class AgentEvents { private readonly ILogTracer _log; - + private readonly IEndpointAuthorization _auth; private readonly IOnefuzzContext _context; - public AgentEvents(ILogTracer log, IOnefuzzContext context) { + public AgentEvents(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { _log = log; + _auth = auth; _context = context; } private static readonly EntityConverter _entityConverter = new(); [Function("AgentEvents")] - public async Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "POST")] HttpRequestData req) { + public Async.Task Run( + [HttpTrigger(AuthorizationLevel.Anonymous, "POST", Route="agents/events")] + HttpRequestData req) + => _auth.CallIfAgent(req, Post); + + private async Async.Task Post(HttpRequestData req) { var request = await RequestHandling.ParseRequest(req); - if (!request.IsOk || request.OkV == null) { + if (!request.IsOk) { return await _context.RequestHandling.NotOk(req, request.ErrorV, context: "node event"); } diff --git a/src/ApiService/ApiService/onefuzzlib/EndpointAuthorization.cs b/src/ApiService/ApiService/onefuzzlib/EndpointAuthorization.cs index 7fec5807d..cfdaad9b9 100644 --- a/src/ApiService/ApiService/onefuzzlib/EndpointAuthorization.cs +++ b/src/ApiService/ApiService/onefuzzlib/EndpointAuthorization.cs @@ -39,8 +39,8 @@ public class EndpointAuthorization : IEndpointAuthorization { if (!tokenResult.IsOk) { return await _context.RequestHandling.NotOk(req, tokenResult.ErrorV, "token verification", HttpStatusCode.Unauthorized); } - var token = tokenResult.OkV!; + var token = tokenResult.OkV; if (await IsUser(token)) { if (!allowUser) { return await Reject(req, token); diff --git a/src/ApiService/IntegrationTests/AgentCanScheduleTests.cs b/src/ApiService/IntegrationTests/AgentCanScheduleTests.cs new file mode 100644 index 000000000..4c9a8220c --- /dev/null +++ b/src/ApiService/IntegrationTests/AgentCanScheduleTests.cs @@ -0,0 +1,52 @@ +using System.Net; +using IntegrationTests.Fakes; +using Microsoft.OneFuzz.Service; +using Xunit; +using Xunit.Abstractions; +using Async = System.Threading.Tasks; + +namespace IntegrationTests; + +[Trait("Category", "Live")] +public class AzureStorageAgentCanScheduleTest : AgentCommandsTestsBase { + public AzureStorageAgentCanScheduleTest(ITestOutputHelper output) + : base(output, Integration.AzureStorage.FromEnvironment()) { } +} + +public class AzuriteAgentCanScheduleTest : AgentEventsTestsBase { + public AzuriteAgentCanScheduleTest(ITestOutputHelper output) + : base(output, new Integration.AzuriteStorage()) { } +} + +public abstract class AgentCanScheduleTestsBase : FunctionTestBase { + public AgentCanScheduleTestsBase(ITestOutputHelper output, IStorage storage) + : base(output, storage) { } + + + [Fact] + public async Async.Task Authorization_IsRequired() { + var auth = new TestEndpointAuthorization(RequestType.NoAuthorization, Logger, Context); + var func = new AgentCanSchedule(Logger, auth, Context); + + var result = await func.Run(TestHttpRequestData.Empty("POST")); + Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); + } + + [Fact] + public async Async.Task UserAuthorization_IsNotPermitted() { + var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); + var func = new AgentCanSchedule(Logger, auth, Context); + + var result = await func.Run(TestHttpRequestData.Empty("POST")); + Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); + } + + [Fact] + public async Async.Task AgentAuthorization_IsAccepted() { + var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); + var func = new AgentCanSchedule(Logger, auth, Context); + + var result = await func.Run(TestHttpRequestData.Empty("POST")); + Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); // BadRequest due to no body, not Unauthorized + } +} diff --git a/src/ApiService/IntegrationTests/AgentCommandsTests.cs b/src/ApiService/IntegrationTests/AgentCommandsTests.cs new file mode 100644 index 000000000..1c8ad94ac --- /dev/null +++ b/src/ApiService/IntegrationTests/AgentCommandsTests.cs @@ -0,0 +1,52 @@ +using System.Net; +using IntegrationTests.Fakes; +using Microsoft.OneFuzz.Service; +using Xunit; +using Xunit.Abstractions; +using Async = System.Threading.Tasks; + +namespace IntegrationTests; + +[Trait("Category", "Live")] +public class AzureStorageAgentCommandsTest : AgentCommandsTestsBase { + public AzureStorageAgentCommandsTest(ITestOutputHelper output) + : base(output, Integration.AzureStorage.FromEnvironment()) { } +} + +public class AzuriteAgentCommandsTest : AgentEventsTestsBase { + public AzuriteAgentCommandsTest(ITestOutputHelper output) + : base(output, new Integration.AzuriteStorage()) { } +} + +public abstract class AgentCommandsTestsBase : FunctionTestBase { + public AgentCommandsTestsBase(ITestOutputHelper output, IStorage storage) + : base(output, storage) { } + + + [Fact] + public async Async.Task Authorization_IsRequired() { + var auth = new TestEndpointAuthorization(RequestType.NoAuthorization, Logger, Context); + var func = new AgentCommands(Logger, auth, Context); + + var result = await func.Run(TestHttpRequestData.Empty("GET")); + Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); + } + + [Fact] + public async Async.Task UserAuthorization_IsNotPermitted() { + var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); + var func = new AgentCommands(Logger, auth, Context); + + var result = await func.Run(TestHttpRequestData.Empty("GET")); + Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); + } + + [Fact] + public async Async.Task AgentAuthorization_IsAccepted() { + var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); + var func = new AgentCommands(Logger, auth, Context); + + var result = await func.Run(TestHttpRequestData.Empty("GET")); + Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); // BadRequest due to no body, not Unauthorized + } +} diff --git a/src/ApiService/IntegrationTests/AgentEventsTests.cs b/src/ApiService/IntegrationTests/AgentEventsTests.cs index b80dd1ddc..76382a43b 100644 --- a/src/ApiService/IntegrationTests/AgentEventsTests.cs +++ b/src/ApiService/IntegrationTests/AgentEventsTests.cs @@ -32,9 +32,28 @@ public abstract class AgentEventsTestsBase : FunctionTestBase { readonly Guid poolId = Guid.NewGuid(); readonly string poolVersion = $"version-{Guid.NewGuid()}"; + [Fact] + public async Async.Task Authorization_IsRequired() { + var auth = new TestEndpointAuthorization(RequestType.NoAuthorization, Logger, Context); + var func = new AgentEvents(Logger, auth, Context); + + var result = await func.Run(TestHttpRequestData.Empty("POST")); + Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); + } + + [Fact] + public async Async.Task UserAuthorization_IsNotPermitted() { + var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); + var func = new AgentEvents(Logger, auth, Context); + + var result = await func.Run(TestHttpRequestData.Empty("POST")); + Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); + } + [Fact] public async Async.Task WorkerEventMustHaveDoneOrRunningSet() { - var func = new AgentEvents(Logger, Context); + var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); + var func = new AgentEvents(Logger, auth, Context); var data = new NodeStateEnvelope( MachineId: Guid.NewGuid(), @@ -53,7 +72,8 @@ public abstract class AgentEventsTestsBase : FunctionTestBase { new Task(jobId, taskId, TaskState.Running, Os.Linux, new TaskConfig(jobId, null, new TaskDetails(TaskType.Coverage, 100)))); - var func = new AgentEvents(Logger, Context); + var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); + var func = new AgentEvents(Logger, auth, Context); var data = new NodeStateEnvelope( MachineId: machineId, @@ -80,8 +100,8 @@ public abstract class AgentEventsTestsBase : FunctionTestBase { new Task(jobId, taskId, TaskState.Running, Os.Linux, new TaskConfig(jobId, null, new TaskDetails(TaskType.Coverage, 100)))); - - var func = new AgentEvents(Logger, Context); + var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); + var func = new AgentEvents(Logger, auth, Context); var data = new NodeStateEnvelope( MachineId: machineId, @@ -107,7 +127,8 @@ public abstract class AgentEventsTestsBase : FunctionTestBase { new Task(jobId, taskId, TaskState.Scheduled, Os.Linux, new TaskConfig(jobId, null, new TaskDetails(TaskType.Coverage, 100)))); - var func = new AgentEvents(Logger, Context); + var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); + var func = new AgentEvents(Logger, auth, Context); var data = new NodeStateEnvelope( MachineId: machineId, @@ -132,7 +153,8 @@ public abstract class AgentEventsTestsBase : FunctionTestBase { await Context.InsertAll( new Node(poolName, machineId, poolId, poolVersion)); - var func = new AgentEvents(Logger, Context); + var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); + var func = new AgentEvents(Logger, auth, Context); var data = new NodeStateEnvelope( MachineId: machineId, Event: new WorkerEvent(Running: new WorkerRunningEvent(taskId))); @@ -148,7 +170,8 @@ public abstract class AgentEventsTestsBase : FunctionTestBase { new Task(jobId, taskId, TaskState.Running, Os.Linux, new TaskConfig(jobId, null, new TaskDetails(TaskType.Coverage, 0)))); - var func = new AgentEvents(Logger, Context); + var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); + var func = new AgentEvents(Logger, auth, Context); var data = new NodeStateEnvelope( MachineId: machineId, Event: new WorkerEvent(Running: new WorkerRunningEvent(taskId))); @@ -165,7 +188,8 @@ public abstract class AgentEventsTestsBase : FunctionTestBase { new Task(jobId, taskId, TaskState.Running, Os.Linux, new TaskConfig(jobId, null, new TaskDetails(TaskType.Coverage, 0)))); - var func = new AgentEvents(Logger, Context); + var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); + var func = new AgentEvents(Logger, auth, Context); var data = new NodeStateEnvelope( MachineId: machineId, Event: new WorkerEvent(Running: new WorkerRunningEvent(taskId))); @@ -205,7 +229,8 @@ public abstract class AgentEventsTestsBase : FunctionTestBase { public async Async.Task NodeStateUpdate_ForMissingNode_IgnoresEvent() { // nothing present in storage - var func = new AgentEvents(Logger, Context); + var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); + var func = new AgentEvents(Logger, auth, Context); var data = new NodeStateEnvelope( MachineId: machineId, Event: new NodeStateUpdate(NodeState.Init)); @@ -220,7 +245,8 @@ public abstract class AgentEventsTestsBase : FunctionTestBase { await Context.InsertAll( new Node(poolName, machineId, poolId, poolVersion, State: NodeState.Init)); - var func = new AgentEvents(Logger, Context); + var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); + var func = new AgentEvents(Logger, auth, Context); var data = new NodeStateEnvelope( MachineId: machineId, Event: new NodeStateUpdate(NodeState.Ready)); @@ -237,7 +263,8 @@ public abstract class AgentEventsTestsBase : FunctionTestBase { await Context.InsertAll( new Node(poolName, machineId, poolId, poolVersion, ReimageRequested: true)); - var func = new AgentEvents(Logger, Context); + var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); + var func = new AgentEvents(Logger, auth, Context); var data = new NodeStateEnvelope( MachineId: machineId, Event: new NodeStateUpdate(NodeState.Free)); @@ -265,7 +292,8 @@ public abstract class AgentEventsTestsBase : FunctionTestBase { await Context.InsertAll( new Node(poolName, machineId, poolId, poolVersion, DeleteRequested: true)); - var func = new AgentEvents(Logger, Context); + var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); + var func = new AgentEvents(Logger, auth, Context); var data = new NodeStateEnvelope( MachineId: machineId, Event: new NodeStateUpdate(NodeState.Free));