Add maxPerPage to ORM (#3016)

* Add support for maxPerPage in OMR

* Fix small bug
This commit is contained in:
Teo Voinea
2023-04-12 16:37:56 -04:00
committed by GitHub
parent 41fa0a78bb
commit c105423d14
5 changed files with 45 additions and 9 deletions

View File

@ -31,7 +31,7 @@ public class AgentCommands {
} }
var nodeCommand = request.OkV; var nodeCommand = request.OkV;
var message = await _context.NodeMessageOperations.GetMessage(nodeCommand.MachineId).FirstOrDefaultAsync(); var message = await _context.NodeMessageOperations.GetMessage(nodeCommand.MachineId);
if (message != null) { if (message != null) {
var command = message.Message; var command = message.Message;
var messageId = message.MessageId; var messageId = message.MessageId;

View File

@ -46,7 +46,7 @@ public class Node {
var (tasks, messages) = await ( var (tasks, messages) = await (
_context.NodeTasksOperations.GetByMachineId(machineId).ToListAsync().AsTask(), _context.NodeTasksOperations.GetByMachineId(machineId).ToListAsync().AsTask(),
_context.NodeMessageOperations.GetMessage(machineId).ToListAsync().AsTask()); _context.NodeMessageOperations.GetMessages(machineId).ToListAsync().AsTask());
var commands = messages.Select(m => m.Message).ToList(); var commands = messages.Select(m => m.Message).ToList();
return await RequestHandling.Ok(req, NodeToNodeSearchResult(node with { Tasks = tasks, Messages = commands })); return await RequestHandling.Ok(req, NodeToNodeSearchResult(node with { Tasks = tasks, Messages = commands }));

View File

@ -1,4 +1,5 @@
using ApiService.OneFuzzLib.Orm; using System.Threading.Tasks;
using ApiService.OneFuzzLib.Orm;
using Microsoft.OneFuzz.Service.OneFuzzLib.Orm; using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
namespace Microsoft.OneFuzz.Service; namespace Microsoft.OneFuzz.Service;
@ -14,7 +15,9 @@ public record NodeMessage(
}; };
public interface INodeMessageOperations : IOrm<NodeMessage> { public interface INodeMessageOperations : IOrm<NodeMessage> {
IAsyncEnumerable<NodeMessage> GetMessage(Guid machineId); IAsyncEnumerable<NodeMessage> GetMessages(Guid machineId);
Async.Task<NodeMessage?> GetMessage(Guid machineId);
Async.Task ClearMessages(Guid machineId); Async.Task ClearMessages(Guid machineId);
Async.Task SendMessage(Guid machineId, NodeCommand message, string? messageId = null); Async.Task SendMessage(Guid machineId, NodeCommand message, string? messageId = null);
@ -25,7 +28,7 @@ public class NodeMessageOperations : Orm<NodeMessage>, INodeMessageOperations {
public NodeMessageOperations(ILogTracer log, IOnefuzzContext context) public NodeMessageOperations(ILogTracer log, IOnefuzzContext context)
: base(log, context) { } : base(log, context) { }
public IAsyncEnumerable<NodeMessage> GetMessage(Guid machineId) public IAsyncEnumerable<NodeMessage> GetMessages(Guid machineId)
=> QueryAsync(Query.PartitionKey(machineId.ToString())); => QueryAsync(Query.PartitionKey(machineId.ToString()));
public async Async.Task ClearMessages(Guid machineId) { public async Async.Task ClearMessages(Guid machineId) {
@ -45,4 +48,7 @@ public class NodeMessageOperations : Orm<NodeMessage>, INodeMessageOperations {
_logTracer.WithHttpStatus(r.ErrorV).Error($"failed to insert message with id: {messageId:Tag:MessageId} for machine id: {machineId:Tag:MachineId} message: {message:Tag:Message}"); _logTracer.WithHttpStatus(r.ErrorV).Error($"failed to insert message with id: {messageId:Tag:MessageId} for machine id: {machineId:Tag:MachineId} message: {message:Tag:Message}");
} }
} }
public async Task<NodeMessage?> GetMessage(Guid machineId)
=> await QueryAsync(Query.PartitionKey(machineId.ToString()), maxPerPage: 1).FirstOrDefaultAsync();
} }

View File

@ -12,7 +12,7 @@ using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
namespace ApiService.OneFuzzLib.Orm { namespace ApiService.OneFuzzLib.Orm {
public interface IOrm<T> where T : EntityBase { public interface IOrm<T> where T : EntityBase {
Task<TableClient> GetTableClient(string table, ResourceIdentifier? accountId = null); Task<TableClient> GetTableClient(string table, ResourceIdentifier? accountId = null);
IAsyncEnumerable<T> QueryAsync(string? filter = null); IAsyncEnumerable<T> QueryAsync(string? filter = null, int? maxPerPage = null);
Task<T> GetEntityAsync(string partitionKey, string rowKey); Task<T> GetEntityAsync(string partitionKey, string rowKey);
Task<ResultVoid<(HttpStatusCode Status, string Reason)>> Insert(T entity); Task<ResultVoid<(HttpStatusCode Status, string Reason)>> Insert(T entity);
@ -49,14 +49,14 @@ namespace ApiService.OneFuzzLib.Orm {
_entityConverter = _context.EntityConverter; _entityConverter = _context.EntityConverter;
} }
public async IAsyncEnumerable<T> QueryAsync(string? filter = null) { public async IAsyncEnumerable<T> QueryAsync(string? filter = null, int? maxPerPage = null) {
var tableClient = await GetTableClient(typeof(T).Name); var tableClient = await GetTableClient(typeof(T).Name);
if (filter == "") { if (filter == "") {
filter = null; filter = null;
} }
await foreach (var x in tableClient.QueryAsync<TableEntity>(filter).Select(x => _entityConverter.ToRecord<T>(x))) { await foreach (var x in tableClient.QueryAsync<TableEntity>(filter: filter, maxPerPage: maxPerPage).Select(x => _entityConverter.ToRecord<T>(x))) {
yield return x; yield return x;
} }
} }

View File

@ -1,4 +1,6 @@
using System.Net; using System;
using System.Net;
using FluentAssertions;
using IntegrationTests.Fakes; using IntegrationTests.Fakes;
using Microsoft.OneFuzz.Service; using Microsoft.OneFuzz.Service;
using Microsoft.OneFuzz.Service.Functions; using Microsoft.OneFuzz.Service.Functions;
@ -50,4 +52,32 @@ public abstract class AgentCommandsTestsBase : FunctionTestBase {
var result = await func.Run(TestHttpRequestData.Empty("GET")); var result = await func.Run(TestHttpRequestData.Empty("GET"));
Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); // BadRequest due to no body, not Unauthorized Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); // BadRequest due to no body, not Unauthorized
} }
[Fact]
public async Async.Task AgentCommand_GetsCommand() {
var machineId = Guid.NewGuid();
var messageId = Guid.NewGuid().ToString();
var command = new NodeCommand {
Stop = new StopNodeCommand()
};
await Context.InsertAll(new[] {
new NodeMessage (
machineId,
messageId,
command
),
});
var commandRequest = new NodeCommandGet(machineId);
var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context);
var func = new AgentCommands(Logger, auth, Context);
var result = await func.Run(TestHttpRequestData.FromJson("GET", commandRequest));
Assert.Equal(HttpStatusCode.OK, result.StatusCode);
var pendingNodeCommand = BodyAs<PendingNodeCommand>(result);
pendingNodeCommand.Envelope.Should().NotBeNull();
pendingNodeCommand.Envelope?.Command.Should().BeEquivalentTo(command);
pendingNodeCommand.Envelope?.MessageId.Should().Be(messageId);
}
} }