Add maxPerPage to ORM (#3016)

* Add support for maxPerPage in OMR

* Fix small bug
This commit is contained in:
Teo Voinea
2023-04-12 16:37:56 -04:00
committed by GitHub
parent 41fa0a78bb
commit c105423d14
5 changed files with 45 additions and 9 deletions

View File

@ -31,7 +31,7 @@ public class AgentCommands {
}
var nodeCommand = request.OkV;
var message = await _context.NodeMessageOperations.GetMessage(nodeCommand.MachineId).FirstOrDefaultAsync();
var message = await _context.NodeMessageOperations.GetMessage(nodeCommand.MachineId);
if (message != null) {
var command = message.Message;
var messageId = message.MessageId;

View File

@ -46,7 +46,7 @@ public class Node {
var (tasks, messages) = await (
_context.NodeTasksOperations.GetByMachineId(machineId).ToListAsync().AsTask(),
_context.NodeMessageOperations.GetMessage(machineId).ToListAsync().AsTask());
_context.NodeMessageOperations.GetMessages(machineId).ToListAsync().AsTask());
var commands = messages.Select(m => m.Message).ToList();
return await RequestHandling.Ok(req, NodeToNodeSearchResult(node with { Tasks = tasks, Messages = commands }));

View File

@ -1,4 +1,5 @@
using ApiService.OneFuzzLib.Orm;
using System.Threading.Tasks;
using ApiService.OneFuzzLib.Orm;
using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
namespace Microsoft.OneFuzz.Service;
@ -14,7 +15,9 @@ public record NodeMessage(
};
public interface INodeMessageOperations : IOrm<NodeMessage> {
IAsyncEnumerable<NodeMessage> GetMessage(Guid machineId);
IAsyncEnumerable<NodeMessage> GetMessages(Guid machineId);
Async.Task<NodeMessage?> GetMessage(Guid machineId);
Async.Task ClearMessages(Guid machineId);
Async.Task SendMessage(Guid machineId, NodeCommand message, string? messageId = null);
@ -25,7 +28,7 @@ public class NodeMessageOperations : Orm<NodeMessage>, INodeMessageOperations {
public NodeMessageOperations(ILogTracer log, IOnefuzzContext context)
: base(log, context) { }
public IAsyncEnumerable<NodeMessage> GetMessage(Guid machineId)
public IAsyncEnumerable<NodeMessage> GetMessages(Guid machineId)
=> QueryAsync(Query.PartitionKey(machineId.ToString()));
public async Async.Task ClearMessages(Guid machineId) {
@ -45,4 +48,7 @@ public class NodeMessageOperations : Orm<NodeMessage>, INodeMessageOperations {
_logTracer.WithHttpStatus(r.ErrorV).Error($"failed to insert message with id: {messageId:Tag:MessageId} for machine id: {machineId:Tag:MachineId} message: {message:Tag:Message}");
}
}
public async Task<NodeMessage?> GetMessage(Guid machineId)
=> await QueryAsync(Query.PartitionKey(machineId.ToString()), maxPerPage: 1).FirstOrDefaultAsync();
}

View File

@ -12,7 +12,7 @@ using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
namespace ApiService.OneFuzzLib.Orm {
public interface IOrm<T> where T : EntityBase {
Task<TableClient> GetTableClient(string table, ResourceIdentifier? accountId = null);
IAsyncEnumerable<T> QueryAsync(string? filter = null);
IAsyncEnumerable<T> QueryAsync(string? filter = null, int? maxPerPage = null);
Task<T> GetEntityAsync(string partitionKey, string rowKey);
Task<ResultVoid<(HttpStatusCode Status, string Reason)>> Insert(T entity);
@ -49,14 +49,14 @@ namespace ApiService.OneFuzzLib.Orm {
_entityConverter = _context.EntityConverter;
}
public async IAsyncEnumerable<T> QueryAsync(string? filter = null) {
public async IAsyncEnumerable<T> QueryAsync(string? filter = null, int? maxPerPage = null) {
var tableClient = await GetTableClient(typeof(T).Name);
if (filter == "") {
filter = null;
}
await foreach (var x in tableClient.QueryAsync<TableEntity>(filter).Select(x => _entityConverter.ToRecord<T>(x))) {
await foreach (var x in tableClient.QueryAsync<TableEntity>(filter: filter, maxPerPage: maxPerPage).Select(x => _entityConverter.ToRecord<T>(x))) {
yield return x;
}
}

View File

@ -1,4 +1,6 @@
using System.Net;
using System;
using System.Net;
using FluentAssertions;
using IntegrationTests.Fakes;
using Microsoft.OneFuzz.Service;
using Microsoft.OneFuzz.Service.Functions;
@ -50,4 +52,32 @@ public abstract class AgentCommandsTestsBase : FunctionTestBase {
var result = await func.Run(TestHttpRequestData.Empty("GET"));
Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); // BadRequest due to no body, not Unauthorized
}
[Fact]
public async Async.Task AgentCommand_GetsCommand() {
var machineId = Guid.NewGuid();
var messageId = Guid.NewGuid().ToString();
var command = new NodeCommand {
Stop = new StopNodeCommand()
};
await Context.InsertAll(new[] {
new NodeMessage (
machineId,
messageId,
command
),
});
var commandRequest = new NodeCommandGet(machineId);
var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context);
var func = new AgentCommands(Logger, auth, Context);
var result = await func.Run(TestHttpRequestData.FromJson("GET", commandRequest));
Assert.Equal(HttpStatusCode.OK, result.StatusCode);
var pendingNodeCommand = BodyAs<PendingNodeCommand>(result);
pendingNodeCommand.Envelope.Should().NotBeNull();
pendingNodeCommand.Envelope?.Command.Should().BeEquivalentTo(command);
pendingNodeCommand.Envelope?.MessageId.Should().Be(messageId);
}
}