diff --git a/src/ApiService/ApiService/Functions/AgentCommands.cs b/src/ApiService/ApiService/Functions/AgentCommands.cs index b3a8678df..5c1d7721a 100644 --- a/src/ApiService/ApiService/Functions/AgentCommands.cs +++ b/src/ApiService/ApiService/Functions/AgentCommands.cs @@ -31,7 +31,7 @@ public class AgentCommands { } var nodeCommand = request.OkV; - var message = await _context.NodeMessageOperations.GetMessage(nodeCommand.MachineId).FirstOrDefaultAsync(); + var message = await _context.NodeMessageOperations.GetMessage(nodeCommand.MachineId); if (message != null) { var command = message.Message; var messageId = message.MessageId; diff --git a/src/ApiService/ApiService/Functions/Node.cs b/src/ApiService/ApiService/Functions/Node.cs index 8472d7473..2bdb93d56 100644 --- a/src/ApiService/ApiService/Functions/Node.cs +++ b/src/ApiService/ApiService/Functions/Node.cs @@ -46,7 +46,7 @@ public class Node { var (tasks, messages) = await ( _context.NodeTasksOperations.GetByMachineId(machineId).ToListAsync().AsTask(), - _context.NodeMessageOperations.GetMessage(machineId).ToListAsync().AsTask()); + _context.NodeMessageOperations.GetMessages(machineId).ToListAsync().AsTask()); var commands = messages.Select(m => m.Message).ToList(); return await RequestHandling.Ok(req, NodeToNodeSearchResult(node with { Tasks = tasks, Messages = commands })); diff --git a/src/ApiService/ApiService/onefuzzlib/NodeMessageOperations.cs b/src/ApiService/ApiService/onefuzzlib/NodeMessageOperations.cs index 48df6bb3c..9d31d41cb 100644 --- a/src/ApiService/ApiService/onefuzzlib/NodeMessageOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/NodeMessageOperations.cs @@ -1,4 +1,5 @@ -using ApiService.OneFuzzLib.Orm; +using System.Threading.Tasks; +using ApiService.OneFuzzLib.Orm; using Microsoft.OneFuzz.Service.OneFuzzLib.Orm; namespace Microsoft.OneFuzz.Service; @@ -14,7 +15,9 @@ public record NodeMessage( }; public interface INodeMessageOperations : IOrm { - IAsyncEnumerable GetMessage(Guid machineId); + IAsyncEnumerable GetMessages(Guid machineId); + + Async.Task GetMessage(Guid machineId); Async.Task ClearMessages(Guid machineId); Async.Task SendMessage(Guid machineId, NodeCommand message, string? messageId = null); @@ -25,7 +28,7 @@ public class NodeMessageOperations : Orm, INodeMessageOperations { public NodeMessageOperations(ILogTracer log, IOnefuzzContext context) : base(log, context) { } - public IAsyncEnumerable GetMessage(Guid machineId) + public IAsyncEnumerable GetMessages(Guid machineId) => QueryAsync(Query.PartitionKey(machineId.ToString())); public async Async.Task ClearMessages(Guid machineId) { @@ -45,4 +48,7 @@ public class NodeMessageOperations : Orm, INodeMessageOperations { _logTracer.WithHttpStatus(r.ErrorV).Error($"failed to insert message with id: {messageId:Tag:MessageId} for machine id: {machineId:Tag:MachineId} message: {message:Tag:Message}"); } } + + public async Task GetMessage(Guid machineId) + => await QueryAsync(Query.PartitionKey(machineId.ToString()), maxPerPage: 1).FirstOrDefaultAsync(); } diff --git a/src/ApiService/ApiService/onefuzzlib/orm/Orm.cs b/src/ApiService/ApiService/onefuzzlib/orm/Orm.cs index 578f629fa..b43c08e15 100644 --- a/src/ApiService/ApiService/onefuzzlib/orm/Orm.cs +++ b/src/ApiService/ApiService/onefuzzlib/orm/Orm.cs @@ -12,7 +12,7 @@ using Microsoft.OneFuzz.Service.OneFuzzLib.Orm; namespace ApiService.OneFuzzLib.Orm { public interface IOrm where T : EntityBase { Task GetTableClient(string table, ResourceIdentifier? accountId = null); - IAsyncEnumerable QueryAsync(string? filter = null); + IAsyncEnumerable QueryAsync(string? filter = null, int? maxPerPage = null); Task GetEntityAsync(string partitionKey, string rowKey); Task> Insert(T entity); @@ -49,14 +49,14 @@ namespace ApiService.OneFuzzLib.Orm { _entityConverter = _context.EntityConverter; } - public async IAsyncEnumerable QueryAsync(string? filter = null) { + public async IAsyncEnumerable QueryAsync(string? filter = null, int? maxPerPage = null) { var tableClient = await GetTableClient(typeof(T).Name); if (filter == "") { filter = null; } - await foreach (var x in tableClient.QueryAsync(filter).Select(x => _entityConverter.ToRecord(x))) { + await foreach (var x in tableClient.QueryAsync(filter: filter, maxPerPage: maxPerPage).Select(x => _entityConverter.ToRecord(x))) { yield return x; } } diff --git a/src/ApiService/IntegrationTests/AgentCommandsTests.cs b/src/ApiService/IntegrationTests/AgentCommandsTests.cs index 9ba5e480e..b18b32e97 100644 --- a/src/ApiService/IntegrationTests/AgentCommandsTests.cs +++ b/src/ApiService/IntegrationTests/AgentCommandsTests.cs @@ -1,4 +1,6 @@ -using System.Net; +using System; +using System.Net; +using FluentAssertions; using IntegrationTests.Fakes; using Microsoft.OneFuzz.Service; using Microsoft.OneFuzz.Service.Functions; @@ -50,4 +52,32 @@ public abstract class AgentCommandsTestsBase : FunctionTestBase { var result = await func.Run(TestHttpRequestData.Empty("GET")); Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); // BadRequest due to no body, not Unauthorized } + + [Fact] + public async Async.Task AgentCommand_GetsCommand() { + var machineId = Guid.NewGuid(); + var messageId = Guid.NewGuid().ToString(); + var command = new NodeCommand { + Stop = new StopNodeCommand() + }; + await Context.InsertAll(new[] { + new NodeMessage ( + machineId, + messageId, + command + ), + }); + + var commandRequest = new NodeCommandGet(machineId); + var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); + var func = new AgentCommands(Logger, auth, Context); + + var result = await func.Run(TestHttpRequestData.FromJson("GET", commandRequest)); + Assert.Equal(HttpStatusCode.OK, result.StatusCode); + + var pendingNodeCommand = BodyAs(result); + pendingNodeCommand.Envelope.Should().NotBeNull(); + pendingNodeCommand.Envelope?.Command.Should().BeEquivalentTo(command); + pendingNodeCommand.Envelope?.MessageId.Should().Be(messageId); + } }