diff --git a/src/ApiService/ApiService/Download.cs b/src/ApiService/ApiService/Download.cs new file mode 100644 index 000000000..5f5516e9b --- /dev/null +++ b/src/ApiService/ApiService/Download.cs @@ -0,0 +1,53 @@ +using System.Web; +using Azure.Storage.Sas; +using Microsoft.Azure.Functions.Worker; +using Microsoft.Azure.Functions.Worker.Http; + +namespace Microsoft.OneFuzz.Service; + +public class Download { + private readonly IEndpointAuthorization _auth; + private readonly IOnefuzzContext _context; + + public Download(IEndpointAuthorization auth, IOnefuzzContext context) { + _auth = auth; + _context = context; + } + + // [Function("Download")] + public Async.Task Run([HttpTrigger("GET")] HttpRequestData req) + => _auth.CallIfUser(req, Get); + + private async Async.Task Get(HttpRequestData req) { + var query = HttpUtility.ParseQueryString(req.Url.Query); + + var container = query["container"]; + if (container is null) { + return await _context.RequestHandling.NotOk( + req, + new Error( + ErrorCode.INVALID_REQUEST, + new string[] { "'container' query parameter must be provided" }), + "download"); + } + + var filename = query["filename"]; + if (filename is null) { + return await _context.RequestHandling.NotOk( + req, + new Error( + ErrorCode.INVALID_REQUEST, + new string[] { "'filename' query parameter must be provided" }), + "download"); + } + + var sasUri = await _context.Containers.GetFileSasUrl( + new Container(container), + filename, + StorageType.Corpus, + BlobSasPermissions.Read, + TimeSpan.FromMinutes(5)); + + return RequestHandling.Redirect(req, sasUri); + } +} diff --git a/src/ApiService/ApiService/onefuzzlib/Containers.cs b/src/ApiService/ApiService/onefuzzlib/Containers.cs index a05a3d7c1..f28286e8f 100644 --- a/src/ApiService/ApiService/onefuzzlib/Containers.cs +++ b/src/ApiService/ApiService/onefuzzlib/Containers.cs @@ -13,7 +13,7 @@ public interface IContainers { public Async.Task FindContainer(Container container, StorageType storageType); - public Async.Task GetFileSasUrl(Container container, string name, StorageType storageType, BlobSasPermissions permissions, TimeSpan? duration = null); + public Async.Task GetFileSasUrl(Container container, string name, StorageType storageType, BlobSasPermissions permissions, TimeSpan? duration = null); public Async.Task SaveBlob(Container container, string v1, string v2, StorageType config); public Async.Task GetInstanceId(); @@ -110,14 +110,14 @@ public class Containers : IContainers { return new BlobServiceClient(accountUrl, storageKeyCredential); } - public async Async.Task GetFileSasUrl(Container container, string name, StorageType storageType, BlobSasPermissions permissions, TimeSpan? duration = null) { + public async Async.Task GetFileSasUrl(Container container, string name, StorageType storageType, BlobSasPermissions permissions, TimeSpan? duration = null) { var client = await FindContainer(container, storageType) ?? throw new Exception($"unable to find container: {container.ContainerName} - {storageType}"); var (startTime, endTime) = SasTimeWindow(duration ?? TimeSpan.FromDays(30)); var sasBuilder = new BlobSasBuilder(permissions, endTime) { StartsOn = startTime, - BlobContainerName = container.ContainerName, + BlobContainerName = _config.OneFuzzStoragePrefix + container.ContainerName, BlobName = name }; diff --git a/src/ApiService/ApiService/onefuzzlib/Request.cs b/src/ApiService/ApiService/onefuzzlib/Request.cs index be1eafe62..f8e0b4e9c 100644 --- a/src/ApiService/ApiService/onefuzzlib/Request.cs +++ b/src/ApiService/ApiService/onefuzzlib/Request.cs @@ -57,6 +57,13 @@ public class RequestHandling : IRequestHandling { ); } + public static HttpResponseData Redirect(HttpRequestData req, Uri uri) { + var resp = req.CreateResponse(); + resp.StatusCode = HttpStatusCode.Found; + resp.Headers.Add("Location", uri.ToString()); + return resp; + } + public async static Async.Task Ok(HttpRequestData req, IEnumerable response) { var resp = req.CreateResponse(); resp.StatusCode = HttpStatusCode.OK; @@ -75,4 +82,3 @@ public class RequestHandling : IRequestHandling { return await Ok(req, new BaseResponse[] { response }); } } - diff --git a/src/ApiService/Tests/Fakes/TestHttpRequestData.cs b/src/ApiService/Tests/Fakes/TestHttpRequestData.cs index bd32a6bf8..2880629c1 100644 --- a/src/ApiService/Tests/Fakes/TestHttpRequestData.cs +++ b/src/ApiService/Tests/Fakes/TestHttpRequestData.cs @@ -12,21 +12,21 @@ using Moq; namespace Tests.Fakes; sealed class TestHttpRequestData : HttpRequestData { - private static readonly ObjectSerializer Serializer = + private static readonly ObjectSerializer _serializer = // we must use our shared JsonSerializerOptions to be able to serialize & deserialize polymorphic types new JsonObjectSerializer(Microsoft.OneFuzz.Service.OneFuzzLib.Orm.EntityConverter.GetJsonSerializerOptions()); sealed class TestServices : IServiceProvider { sealed class TestOptions : IOptions { // WorkerOptions only has one setting: Serializer - public WorkerOptions Value => new() { Serializer = Serializer }; + public WorkerOptions Value => new() { Serializer = _serializer }; } - static readonly IOptions Options = new TestOptions(); + static readonly IOptions _options = new TestOptions(); public object? GetService(Type serviceType) { if (serviceType == typeof(IOptions)) { - return Options; + return _options; } return null; @@ -42,7 +42,7 @@ sealed class TestHttpRequestData : HttpRequestData { } public static TestHttpRequestData FromJson(string method, T obj) - => new(method, Serializer.Serialize(obj)); + => new(method, _serializer.Serialize(obj)); public static TestHttpRequestData Empty(string method) => new(method, new BinaryData(Array.Empty())); @@ -53,6 +53,7 @@ sealed class TestHttpRequestData : HttpRequestData { _body = body; } + private Uri _url = new("https://example.com/"); private readonly BinaryData _body; public override Stream Body => _body.ToStream(); @@ -61,12 +62,15 @@ sealed class TestHttpRequestData : HttpRequestData { public override IReadOnlyCollection Cookies => throw new NotImplementedException(); - public override Uri Url => throw new NotImplementedException(); public override IEnumerable Identities => throw new NotImplementedException(); public override string Method { get; } + public override Uri Url => _url; + + public void SetUrl(Uri url) => _url = url; + public override HttpResponseData CreateResponse() => new TestHttpResponseData(FunctionContext); } diff --git a/src/ApiService/Tests/Functions/DownloadTests.cs b/src/ApiService/Tests/Functions/DownloadTests.cs new file mode 100644 index 000000000..bc75ce447 --- /dev/null +++ b/src/ApiService/Tests/Functions/DownloadTests.cs @@ -0,0 +1,95 @@ + +using System; +using System.Net; +using System.Net.Http; +using Microsoft.OneFuzz.Service; +using Tests.Fakes; +using Xunit; +using Xunit.Abstractions; + +using Async = System.Threading.Tasks; + +namespace Tests.Functions; + +[Trait("Category", "Integration")] +public class AzureStorageDownloadTest : DownloadTestBase { + public AzureStorageDownloadTest(ITestOutputHelper output) + : base(output, Integration.AzureStorage.FromEnvironment()) { } +} + +public class AzuriteDownloadTest : DownloadTestBase { + public AzuriteDownloadTest(ITestOutputHelper output) + : base(output, new Integration.AzuriteStorage()) { } +} + +public abstract class DownloadTestBase : FunctionTestBase { + public DownloadTestBase(ITestOutputHelper output, IStorage storage) + : base(output, storage) { } + + [Fact] + public async Async.Task Download_WithoutAuthorization_IsRejected() { + var auth = new TestEndpointAuthorization(RequestType.NoAuthorization, Logger, Context); + var func = new Download(auth, Context); + + var result = await func.Run(TestHttpRequestData.Empty("GET")); + Assert.Equal(HttpStatusCode.Unauthorized, result.StatusCode); + + var err = BodyAs(result); + Assert.Equal(ErrorCode.UNAUTHORIZED, err.Code); + } + + [Fact] + public async Async.Task Download_WithoutContainer_IsRejected() { + var req = TestHttpRequestData.Empty("GET"); + var url = new UriBuilder(req.Url) { Query = "filename=xxx" }.Uri; + req.SetUrl(url); + + var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); + var func = new Download(auth, Context); + var result = await func.Run(req); + Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); + + var err = BodyAs(result); + Assert.Equal(ErrorCode.INVALID_REQUEST, err.Code); + } + + [Fact] + public async Async.Task Download_WithoutFilename_IsRejected() { + var req = TestHttpRequestData.Empty("GET"); + var url = new UriBuilder(req.Url) { Query = "container=xxx" }.Uri; + req.SetUrl(url); + + var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); + var func = new Download(auth, Context); + + var result = await func.Run(req); + Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); + + var err = BodyAs(result); + Assert.Equal(ErrorCode.INVALID_REQUEST, err.Code); + } + + [Fact] + public async Async.Task Download_RedirectsToResult_WithLocationHeader() { + // set up a file to download + var container = GetContainerClient("xxx"); + await container.CreateAsync(); + await container.UploadBlobAsync("yyy", new BinaryData("content")); + + var req = TestHttpRequestData.Empty("GET"); + var url = new UriBuilder(req.Url) { Query = "container=xxx&filename=yyy" }.Uri; + req.SetUrl(url); + + var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); + var func = new Download(auth, Context); + + var result = await func.Run(req); + Assert.Equal(HttpStatusCode.Found, result.StatusCode); + var location = Assert.Single(result.Headers.GetValues("Location")); + + // check that the SAS URI works + using var client = new HttpClient(); + var blobContent = await client.GetStringAsync(location); + Assert.Equal("content", blobContent); + } +}