mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-12 01:58:18 +00:00
Enable C# functions in code & fill out missing functionality (#2084)
C# HTTP functions won’t take effect by default so it is safe to enable them in code. Also implement required authentication code.
This commit is contained in:
@ -17,8 +17,8 @@ public class AgentEvents {
|
||||
|
||||
private static readonly EntityConverter _entityConverter = new();
|
||||
|
||||
// [Function("AgentEvents")]
|
||||
public async Async.Task<HttpResponseData> Run([HttpTrigger("post")] HttpRequestData req) {
|
||||
[Function("AgentEvents")]
|
||||
public async Async.Task<HttpResponseData> Run([HttpTrigger(AuthorizationLevel.Anonymous, "POST")] HttpRequestData req) {
|
||||
var request = await RequestHandling.ParseRequest<NodeStateEnvelope>(req);
|
||||
if (!request.IsOk || request.OkV == null) {
|
||||
return await _context.RequestHandling.NotOk(req, request.ErrorV, context: "node event");
|
||||
|
@ -14,8 +14,8 @@ public class Download {
|
||||
_context = context;
|
||||
}
|
||||
|
||||
// [Function("Download")]
|
||||
public Async.Task<HttpResponseData> Run([HttpTrigger("GET")] HttpRequestData req)
|
||||
[Function("Download")]
|
||||
public Async.Task<HttpResponseData> Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET")] HttpRequestData req)
|
||||
=> _auth.CallIfUser(req, Get);
|
||||
|
||||
private async Async.Task<HttpResponseData> Get(HttpRequestData req) {
|
||||
|
44
src/ApiService/ApiService/GroupMembershipChecker.cs
Normal file
44
src/ApiService/ApiService/GroupMembershipChecker.cs
Normal file
@ -0,0 +1,44 @@
|
||||
using System.Net.Http;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.OneFuzz.Service;
|
||||
|
||||
abstract class GroupMembershipChecker {
|
||||
protected abstract Async.Task<IEnumerable<Guid>> GetGroups(Guid memberId);
|
||||
|
||||
public async Async.Task<bool> IsMember(IEnumerable<Guid> groupIds, Guid memberId) {
|
||||
if (groupIds.Contains(memberId)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
var memberGroups = await GetGroups(memberId);
|
||||
if (groupIds.Any(memberGroups.Contains)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
class AzureADGroupMembership : GroupMembershipChecker {
|
||||
private readonly ICreds _creds;
|
||||
public AzureADGroupMembership(ICreds creds) => _creds = creds;
|
||||
protected override async Task<IEnumerable<Guid>> GetGroups(Guid memberId) =>
|
||||
await _creds.QueryMicrosoftGraph<List<Guid>>(HttpMethod.Get, $"users/{memberId}/transitiveMemberOf");
|
||||
}
|
||||
|
||||
class StaticGroupMembership : GroupMembershipChecker {
|
||||
private readonly Dictionary<Guid, List<Guid>> _memberships;
|
||||
public StaticGroupMembership(IDictionary<Guid, Guid[]> memberships) {
|
||||
_memberships = memberships.ToDictionary(kvp => kvp.Key, kvp => kvp.Value.ToList());
|
||||
}
|
||||
|
||||
protected override Task<IEnumerable<Guid>> GetGroups(Guid memberId) {
|
||||
var result = Enumerable.Empty<Guid>();
|
||||
if (_memberships.TryGetValue(memberId, out var found)) {
|
||||
result = found;
|
||||
}
|
||||
|
||||
return Async.Task.FromResult(result);
|
||||
}
|
||||
}
|
@ -31,12 +31,13 @@ public class Info {
|
||||
var asm = Assembly.GetExecutingAssembly();
|
||||
var gitVersion = ReadResource(asm, "ApiService.onefuzzlib.git.version");
|
||||
var buildId = ReadResource(asm, "ApiService.onefuzzlib.build.id");
|
||||
var versionString = asm.GetCustomAttribute<AssemblyInformationalVersionAttribute>()?.InformationalVersion;
|
||||
|
||||
return new InfoResponse(
|
||||
ResourceGroup: resourceGroup,
|
||||
Subscription: subscription,
|
||||
Region: region,
|
||||
Versions: new Dictionary<string, InfoVersion> { { "onefuzz", new(gitVersion, buildId, config.OneFuzzVersion) } },
|
||||
Versions: new Dictionary<string, InfoVersion> { { "onefuzz", new(gitVersion, buildId, versionString ?? "") } },
|
||||
InstanceId: await _context.Containers.GetInstanceId(),
|
||||
InsightsAppid: config.ApplicationInsightsAppId,
|
||||
InsightsInstrumentationKey: config.ApplicationInsightsInstrumentationKey);
|
||||
@ -50,13 +51,13 @@ public class Info {
|
||||
}
|
||||
|
||||
using var sr = new StreamReader(r);
|
||||
return sr.ReadToEnd();
|
||||
return sr.ReadToEnd().Trim();
|
||||
}
|
||||
|
||||
private async Async.Task<HttpResponseData> GetResponse(HttpRequestData req)
|
||||
=> await RequestHandling.Ok(req, await _response.Value);
|
||||
|
||||
// [Function("Info")]
|
||||
public Async.Task<HttpResponseData> Run([HttpTrigger("GET")] HttpRequestData req)
|
||||
[Function("Info")]
|
||||
public Async.Task<HttpResponseData> Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET")] HttpRequestData req)
|
||||
=> _auth.CallIfUser(req, GetResponse);
|
||||
}
|
||||
|
@ -18,8 +18,8 @@ public class NodeFunction {
|
||||
|
||||
private static readonly EntityConverter _entityConverter = new();
|
||||
|
||||
// [Function("Node")
|
||||
public Async.Task<HttpResponseData> Run([HttpTrigger("GET", "PATCH", "POST", "DELETE")] HttpRequestData req) {
|
||||
[Function("Node")]
|
||||
public Async.Task<HttpResponseData> Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET", "PATCH", "POST", "DELETE")] HttpRequestData req) {
|
||||
return _auth.CallIfUser(req, r => r.Method switch {
|
||||
"GET" => Get(r),
|
||||
"PATCH" => Patch(r),
|
||||
|
@ -299,8 +299,8 @@ public record NetworkSecurityGroupConfig(
|
||||
}
|
||||
|
||||
public record ApiAccessRule(
|
||||
string[] Methods,
|
||||
Guid[] AllowedGroups
|
||||
IReadOnlyList<string> Methods,
|
||||
IReadOnlyList<Guid> AllowedGroups
|
||||
);
|
||||
|
||||
//# initial set of admins can only be set during deployment.
|
||||
|
@ -1,4 +1,6 @@
|
||||
using System.Text.Json;
|
||||
using System.Net.Http;
|
||||
using System.Net.Http.Json;
|
||||
using System.Threading.Tasks;
|
||||
using Azure.Core;
|
||||
using Azure.Identity;
|
||||
using Azure.ResourceManager;
|
||||
@ -24,18 +26,21 @@ public interface ICreds {
|
||||
public Async.Task<string> GetBaseRegion();
|
||||
|
||||
public Uri GetInstanceUrl();
|
||||
Guid GetScalesetPrincipalId();
|
||||
public Async.Task<Guid> GetScalesetPrincipalId();
|
||||
public Async.Task<T> QueryMicrosoftGraph<T>(HttpMethod method, string resource);
|
||||
}
|
||||
|
||||
public class Creds : ICreds {
|
||||
private readonly ArmClient _armClient;
|
||||
private readonly DefaultAzureCredential _azureCredential;
|
||||
private readonly IServiceConfig _config;
|
||||
private readonly IHttpClientFactory _httpClientFactory;
|
||||
|
||||
public ArmClient ArmClient => _armClient;
|
||||
|
||||
public Creds(IServiceConfig config) {
|
||||
public Creds(IServiceConfig config, IHttpClientFactory httpClientFactory) {
|
||||
_config = config;
|
||||
_httpClientFactory = httpClientFactory;
|
||||
_azureCredential = new DefaultAzureCredential();
|
||||
_armClient = new ArmClient(this.GetIdentity(), this.GetSubscription());
|
||||
}
|
||||
@ -88,11 +93,14 @@ public class Creds : ICreds {
|
||||
return new Uri($"https://{GetInstanceName()}.azurewebsites.net");
|
||||
}
|
||||
|
||||
public Guid GetScalesetPrincipalId() {
|
||||
var uid = ArmClient.GetGenericResource(
|
||||
new ResourceIdentifier(GetScalesetIdentityResourcePath())
|
||||
);
|
||||
var principalId = JsonSerializer.Deserialize<JsonDocument>(uid.Data.Properties.ToString())?.RootElement.GetProperty("principalId").GetString()!;
|
||||
public record ScaleSetIdentity(string principalId);
|
||||
|
||||
public async Async.Task<Guid> GetScalesetPrincipalId() {
|
||||
var path = GetScalesetIdentityResourcePath();
|
||||
var uid = ArmClient.GetGenericResource(new ResourceIdentifier(path));
|
||||
|
||||
var resource = await uid.GetAsync();
|
||||
var principalId = resource.Value.Data.Properties.ToObjectFromJson<ScaleSetIdentity>().principalId;
|
||||
return new Guid(principalId);
|
||||
}
|
||||
|
||||
@ -102,4 +110,44 @@ public class Creds : ICreds {
|
||||
|
||||
return $"{resourceGroupPath}/Microsoft.ManagedIdentity/userAssignedIdentities/{scalesetIdName}";
|
||||
}
|
||||
|
||||
|
||||
// https://docs.microsoft.com/en-us/graph/api/overview?view=graph-rest-1.0
|
||||
private static readonly Uri _graphResource = new("https://graph.microsoft.com");
|
||||
private static readonly Uri _graphResourceEndpoint = new("https://graph.microsoft.com/v1.0");
|
||||
|
||||
public async Task<T> QueryMicrosoftGraph<T>(HttpMethod method, string resource) {
|
||||
var cred = GetIdentity();
|
||||
|
||||
var scopes = new string[] { $"{_graphResource}/.default" };
|
||||
var accessToken = await cred.GetTokenAsync(new TokenRequestContext(scopes));
|
||||
|
||||
var uri = new Uri($"{_graphResourceEndpoint}/{resource}");
|
||||
using var httpClient = _httpClientFactory.CreateClient();
|
||||
using var response = await httpClient.SendAsync(new HttpRequestMessage {
|
||||
Headers = {
|
||||
{"Authorization", $"Bearer {accessToken.Token}"},
|
||||
{"Content-Type", "application/json"},
|
||||
},
|
||||
Method = method,
|
||||
RequestUri = uri,
|
||||
});
|
||||
|
||||
if (response.IsSuccessStatusCode) {
|
||||
var result = await response.Content.ReadFromJsonAsync<T>();
|
||||
if (result is null) {
|
||||
throw new GraphQueryException($"invalid data expected a json object: HTTP {response.StatusCode}");
|
||||
}
|
||||
|
||||
return result;
|
||||
} else {
|
||||
var errorText = await response.Content.ReadAsStringAsync();
|
||||
throw new GraphQueryException($"request did not succeed: HTTP {response.StatusCode} - {errorText}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class GraphQueryException : Exception {
|
||||
public GraphQueryException(string? message) : base(message) {
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,5 @@
|
||||
using System.Net;
|
||||
using System.Net.Http;
|
||||
using Microsoft.Azure.Functions.Worker.Http;
|
||||
|
||||
namespace Microsoft.OneFuzz.Service;
|
||||
@ -45,13 +46,12 @@ public class EndpointAuthorization : IEndpointAuthorization {
|
||||
return await Reject(req, token);
|
||||
}
|
||||
|
||||
var access = CheckAccess(req);
|
||||
var access = await CheckAccess(req);
|
||||
if (!access.IsOk) {
|
||||
return await _context.RequestHandling.NotOk(req, access.ErrorV, "access control", HttpStatusCode.Unauthorized);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if (await IsAgent(token) && !allowAgent) {
|
||||
return await Reject(req, token);
|
||||
}
|
||||
@ -132,8 +132,54 @@ public class EndpointAuthorization : IEndpointAuthorization {
|
||||
}
|
||||
}
|
||||
|
||||
public OneFuzzResultVoid CheckAccess(HttpRequestData req) {
|
||||
throw new NotImplementedException();
|
||||
public async Async.Task<OneFuzzResultVoid> CheckAccess(HttpRequestData req) {
|
||||
var instanceConfig = await _context.ConfigOperations.Fetch();
|
||||
|
||||
var rules = GetRules(instanceConfig);
|
||||
if (rules is null) {
|
||||
return OneFuzzResultVoid.Ok;
|
||||
}
|
||||
|
||||
var path = req.Url.AbsolutePath;
|
||||
var rule = rules.GetMatchingRules(new HttpMethod(req.Method), path);
|
||||
if (rule is null) {
|
||||
return OneFuzzResultVoid.Ok;
|
||||
}
|
||||
|
||||
var memberId = Guid.Parse(req.Headers.GetValues("x-ms-client-principal-id").Single());
|
||||
try {
|
||||
var membershipChecker = CreateGroupMembershipChecker(instanceConfig);
|
||||
var allowed = await membershipChecker.IsMember(rule.AllowedGroupsIds, memberId);
|
||||
if (!allowed) {
|
||||
_log.Error($"unauthorized access: {memberId} is not authorized to access {path}");
|
||||
return new Error(
|
||||
Code: ErrorCode.UNAUTHORIZED,
|
||||
Errors: new string[] { "not approved to use this endpoint" });
|
||||
} else {
|
||||
return OneFuzzResultVoid.Ok;
|
||||
}
|
||||
} catch (Exception ex) {
|
||||
return new Error(
|
||||
Code: ErrorCode.UNAUTHORIZED,
|
||||
Errors: new string[] { "unable to interact with graph", ex.Message });
|
||||
}
|
||||
}
|
||||
|
||||
private GroupMembershipChecker CreateGroupMembershipChecker(InstanceConfig config) {
|
||||
if (config.GroupMembership is not null) {
|
||||
return new StaticGroupMembership(config.GroupMembership);
|
||||
}
|
||||
|
||||
return new AzureADGroupMembership(_context.Creds);
|
||||
}
|
||||
|
||||
private static RequestAccess? GetRules(InstanceConfig config) {
|
||||
var accessRules = config?.ApiAccessRules;
|
||||
if (accessRules is not null) {
|
||||
return RequestAccess.Build(accessRules);
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
public async Async.Task<bool> IsAgent(UserInfo tokenData) {
|
||||
@ -143,7 +189,7 @@ public class EndpointAuthorization : IEndpointAuthorization {
|
||||
return true;
|
||||
}
|
||||
|
||||
var principalId = _context.Creds.GetScalesetPrincipalId();
|
||||
var principalId = await _context.Creds.GetScalesetPrincipalId();
|
||||
return principalId == tokenData.ObjectId;
|
||||
}
|
||||
|
||||
|
@ -18,6 +18,7 @@ public class ConfigOperations : Orm<InstanceConfig>, IConfigOperations {
|
||||
}
|
||||
|
||||
public async Task<InstanceConfig> Fetch() {
|
||||
// TODO: cache this for some period
|
||||
var key = _context.ServiceConfiguration.OneFuzzInstanceName ?? throw new Exception("Environment variable ONEFUZZ_INSTANCE_NAME is not set");
|
||||
var config = await GetEntityAsync(key, key);
|
||||
return config;
|
||||
|
93
src/ApiService/ApiService/onefuzzlib/RequestAccess.cs
Normal file
93
src/ApiService/ApiService/onefuzzlib/RequestAccess.cs
Normal file
@ -0,0 +1,93 @@
|
||||
|
||||
using System.Net.Http;
|
||||
|
||||
namespace Microsoft.OneFuzz.Service;
|
||||
|
||||
public class RequestAccess {
|
||||
private readonly Node _root = new();
|
||||
|
||||
public record Rules(IReadOnlyList<Guid> AllowedGroupsIds);
|
||||
record Node(
|
||||
// HTTP Method -> Rules
|
||||
Dictionary<HttpMethod, Rules> Rules,
|
||||
// Path Segment -> Node
|
||||
Dictionary<string, Node> Children) {
|
||||
public Node() : this(new(), new()) { }
|
||||
}
|
||||
|
||||
private void AddUri(IEnumerable<HttpMethod> methods, string path, Rules rules) {
|
||||
var segments = path.Split('/', StringSplitOptions.RemoveEmptyEntries);
|
||||
if (!segments.Any()) {
|
||||
return;
|
||||
}
|
||||
|
||||
var currentNode = _root;
|
||||
var currentSegmentIndex = 0;
|
||||
|
||||
while (currentSegmentIndex < segments.Length) {
|
||||
var currentSegment = segments[currentSegmentIndex];
|
||||
if (currentNode.Children.ContainsKey(currentSegment)) {
|
||||
currentNode = currentNode.Children[currentSegment];
|
||||
currentSegmentIndex++;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// we found a node matching this exact path
|
||||
// This means that there is an existing rule causing a conflict
|
||||
if (currentSegmentIndex == segments.Length) {
|
||||
var conflictingMethod = methods.FirstOrDefault(m => currentNode.Rules.ContainsKey(m));
|
||||
if (conflictingMethod is not null) {
|
||||
throw new RuleConflictException($"Conflicting rules on {conflictingMethod} {path}");
|
||||
}
|
||||
}
|
||||
|
||||
while (currentSegmentIndex < segments.Length) {
|
||||
var currentSegment = segments[currentSegmentIndex];
|
||||
currentNode = currentNode.Children[currentSegment] = new Node();
|
||||
currentSegmentIndex++;
|
||||
}
|
||||
|
||||
foreach (var method in methods) {
|
||||
currentNode.Rules[method] = rules;
|
||||
}
|
||||
}
|
||||
|
||||
public static RequestAccess Build(IDictionary<string, ApiAccessRule> rules) {
|
||||
var result = new RequestAccess();
|
||||
foreach (var (endpoint, rule) in rules) {
|
||||
result.AddUri(rule.Methods.Select(x => new HttpMethod(x)), endpoint, new Rules(rule.AllowedGroups));
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
public Rules? GetMatchingRules(HttpMethod method, string path) {
|
||||
var segments = path.Split("/", StringSplitOptions.RemoveEmptyEntries);
|
||||
|
||||
var currentNode = _root;
|
||||
currentNode.Rules.TryGetValue(method, out var currentRule);
|
||||
|
||||
foreach (var currentSegment in segments) {
|
||||
if (currentNode.Children.TryGetValue(currentSegment, out var node)) {
|
||||
currentNode = node;
|
||||
} else if (currentNode.Children.TryGetValue("*", out var starNode)) {
|
||||
currentNode = starNode;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
|
||||
if (currentNode.Rules.TryGetValue(method, out var rule)) {
|
||||
currentRule = rule;
|
||||
}
|
||||
}
|
||||
|
||||
return currentRule;
|
||||
}
|
||||
}
|
||||
|
||||
public sealed class RuleConflictException : Exception {
|
||||
public RuleConflictException(string? message) : base(message) {
|
||||
}
|
||||
}
|
@ -1,4 +1,5 @@
|
||||
using System;
|
||||
using System.Net.Http;
|
||||
using System.Threading.Tasks;
|
||||
using Azure.Core;
|
||||
using Azure.Identity;
|
||||
@ -50,7 +51,11 @@ class TestCreds : ICreds {
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
public Guid GetScalesetPrincipalId() {
|
||||
public Task<Guid> GetScalesetPrincipalId() {
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
public Task<T> QueryMicrosoftGraph<T>(HttpMethod method, string resource) {
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
}
|
||||
|
@ -334,6 +334,9 @@ namespace Tests {
|
||||
|
||||
public static Arbitrary<PoolName> PoolName { get; } = OrmGenerators.PoolNameGen.ToArbitrary();
|
||||
|
||||
public static Arbitrary<IReadOnlyList<T>> ReadOnlyList<T>()
|
||||
=> Arb.Default.List<T>().Convert(x => (IReadOnlyList<T>)x, x => (List<T>)x);
|
||||
|
||||
public static Arbitrary<Version> Version() {
|
||||
return Arb.From(OrmGenerators.Version());
|
||||
}
|
||||
|
152
src/ApiService/Tests/RequestAccessTests.cs
Normal file
152
src/ApiService/Tests/RequestAccessTests.cs
Normal file
@ -0,0 +1,152 @@
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Net.Http;
|
||||
using Microsoft.OneFuzz.Service;
|
||||
using Xunit;
|
||||
|
||||
namespace Tests;
|
||||
|
||||
public class RequestAccessTests {
|
||||
|
||||
[Fact]
|
||||
public void TestEmpty() {
|
||||
var requestAccess1 = RequestAccess.Build(new Dictionary<string, ApiAccessRule>());
|
||||
var rules1 = requestAccess1.GetMatchingRules(HttpMethod.Get, "a/b/c");
|
||||
Assert.Null(rules1);
|
||||
|
||||
var requestAccess2 = RequestAccess.Build(
|
||||
new Dictionary<string, ApiAccessRule>{
|
||||
{ "a/b/c", new ApiAccessRule(
|
||||
Methods: new[]{"get"},
|
||||
AllowedGroups: new[]{Guid.NewGuid()})}});
|
||||
|
||||
var rules2 = requestAccess2.GetMatchingRules(HttpMethod.Get, "");
|
||||
Assert.Null(rules2);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void TestExactMatch() {
|
||||
var guid1 = Guid.NewGuid();
|
||||
var requestAccess = RequestAccess.Build(
|
||||
new Dictionary<string, ApiAccessRule>{
|
||||
{ "a/b/c", new ApiAccessRule(
|
||||
Methods: new[]{"get"},
|
||||
AllowedGroups: new []{guid1})}});
|
||||
|
||||
var rules1 = requestAccess.GetMatchingRules(HttpMethod.Get, "a/b/c");
|
||||
Assert.NotNull(rules1);
|
||||
var foundGuid = Assert.Single(rules1!.AllowedGroupsIds);
|
||||
Assert.Equal(guid1, foundGuid);
|
||||
|
||||
var rules2 = requestAccess.GetMatchingRules(HttpMethod.Get, "b/b/e");
|
||||
Assert.Null(rules2);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void TestWildcard() {
|
||||
var guid1 = Guid.NewGuid();
|
||||
var requestAccess = RequestAccess.Build(
|
||||
new Dictionary<string, ApiAccessRule>{
|
||||
{ "b/*/c", new ApiAccessRule(
|
||||
Methods: new[]{"get"},
|
||||
AllowedGroups: new []{guid1})}});
|
||||
|
||||
var rules = requestAccess.GetMatchingRules(HttpMethod.Get, "b/b/c");
|
||||
Assert.NotNull(rules);
|
||||
var foundGuid = Assert.Single(rules!.AllowedGroupsIds);
|
||||
Assert.Equal(guid1, foundGuid);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void TestAddingRuleOnSamePath() {
|
||||
Assert.Throws<RuleConflictException>(() => {
|
||||
var guid1 = Guid.NewGuid();
|
||||
RequestAccess.Build(
|
||||
new Dictionary<string, ApiAccessRule>{
|
||||
{ "a/b/c", new ApiAccessRule(
|
||||
Methods: new[]{"get"},
|
||||
AllowedGroups: new []{guid1})},
|
||||
{ "a/b/c/", new ApiAccessRule(
|
||||
Methods: new[]{"get"},
|
||||
AllowedGroups: Array.Empty<Guid>())}});
|
||||
});
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void TestPriority() {
|
||||
// The most specific rules takes priority over the ones containing a wildcard
|
||||
var guid1 = Guid.NewGuid();
|
||||
var guid2 = Guid.NewGuid();
|
||||
var requestAccess = RequestAccess.Build(
|
||||
new Dictionary<string, ApiAccessRule>{
|
||||
{ "a/*/c", new ApiAccessRule(
|
||||
Methods: new[]{"get"},
|
||||
AllowedGroups: new []{guid1})},
|
||||
{ "a/b/c", new ApiAccessRule(
|
||||
Methods: new[]{"get"},
|
||||
AllowedGroups: new[]{guid2})}});
|
||||
|
||||
var rules = requestAccess.GetMatchingRules(HttpMethod.Get, "a/b/c");
|
||||
Assert.NotNull(rules);
|
||||
Assert.Equal(guid2, Assert.Single(rules!.AllowedGroupsIds)); // should match the most specific rule
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void TestInheritRule() {
|
||||
// if a path has no specific rule. it inherits from the parents
|
||||
// /a/b/c inherit from a/b
|
||||
var guid1 = Guid.NewGuid();
|
||||
var guid2 = Guid.NewGuid();
|
||||
var guid3 = Guid.NewGuid();
|
||||
var requestAccess = RequestAccess.Build(
|
||||
new Dictionary<string, ApiAccessRule>{
|
||||
{ "a/b/c", new ApiAccessRule(
|
||||
Methods: new[]{"get"},
|
||||
AllowedGroups: new []{guid1})},
|
||||
{ "f/*/c", new ApiAccessRule(
|
||||
Methods: new[]{"get"},
|
||||
AllowedGroups: new[]{guid2})},
|
||||
{ "a/b", new ApiAccessRule(
|
||||
Methods: new[]{"post"},
|
||||
AllowedGroups: new []{guid3})}});
|
||||
|
||||
// should inherit a/b/c
|
||||
var rules1 = requestAccess.GetMatchingRules(HttpMethod.Get, "a/b/c/d");
|
||||
Assert.NotNull(rules1);
|
||||
Assert.Equal(guid1, Assert.Single(rules1!.AllowedGroupsIds));
|
||||
|
||||
// should inherit f/*/c
|
||||
var rules2 = requestAccess.GetMatchingRules(HttpMethod.Get, "f/b/c/d");
|
||||
Assert.NotNull(rules2);
|
||||
Assert.Equal(guid2, Assert.Single(rules2!.AllowedGroupsIds));
|
||||
|
||||
// post should inherit a/b
|
||||
var rules3 = requestAccess.GetMatchingRules(HttpMethod.Post, "a/b/c/d");
|
||||
Assert.NotNull(rules3);
|
||||
Assert.Equal(guid3, Assert.Single(rules3!.AllowedGroupsIds));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void TestOverrideRule() {
|
||||
var guid1 = Guid.NewGuid();
|
||||
var guid2 = Guid.NewGuid();
|
||||
var requestAccess = RequestAccess.Build(
|
||||
new Dictionary<string, ApiAccessRule>{
|
||||
{ "a/b/c", new ApiAccessRule(
|
||||
Methods: new[]{"get"},
|
||||
AllowedGroups: new []{guid1})},
|
||||
{ "a/b/c/d", new ApiAccessRule(
|
||||
Methods: new[]{"get"},
|
||||
AllowedGroups: new []{guid2})}});
|
||||
|
||||
// should inherit a/b/c
|
||||
var rules1 = requestAccess.GetMatchingRules(HttpMethod.Get, "a/b/c");
|
||||
Assert.NotNull(rules1);
|
||||
Assert.Equal(guid1, Assert.Single(rules1!.AllowedGroupsIds));
|
||||
|
||||
// should inherit a/b/c/d
|
||||
var rules2 = requestAccess.GetMatchingRules(HttpMethod.Get, "a/b/c/d");
|
||||
Assert.NotNull(rules2);
|
||||
Assert.Equal(guid2, Assert.Single(rules2!.AllowedGroupsIds));
|
||||
}
|
||||
}
|
@ -94,6 +94,8 @@ FUNC_TOOLS_ERROR = (
|
||||
"https://github.com/Azure/azure-functions-core-tools#installing"
|
||||
)
|
||||
|
||||
DOTNET_APPLICATION_SUFFIX = "-net"
|
||||
|
||||
logger = logging.getLogger("deploy")
|
||||
|
||||
|
||||
@ -289,31 +291,49 @@ class Client:
|
||||
"cli_password", object_id, self.get_subscription_id()
|
||||
)
|
||||
|
||||
def get_instance_url(self) -> str:
|
||||
def get_instance_urls(self) -> List[str]:
|
||||
# The url to access the instance
|
||||
# This also represents the legacy identifier_uris of the application
|
||||
# registration
|
||||
if self.multi_tenant_domain:
|
||||
return "https://%s/%s" % (
|
||||
self.multi_tenant_domain,
|
||||
self.application_name,
|
||||
)
|
||||
return [
|
||||
"https://%s/%s" % (self.multi_tenant_domain, name)
|
||||
for name in [
|
||||
self.application_name,
|
||||
self.application_name + DOTNET_APPLICATION_SUFFIX,
|
||||
]
|
||||
]
|
||||
else:
|
||||
return "https://%s.azurewebsites.net" % self.application_name
|
||||
return [
|
||||
"https://%s.azurewebsites.net" % name
|
||||
for name in [
|
||||
self.application_name,
|
||||
self.application_name + DOTNET_APPLICATION_SUFFIX,
|
||||
]
|
||||
]
|
||||
|
||||
def get_identifier_url(self) -> str:
|
||||
def get_identifier_urls(self) -> List[str]:
|
||||
# This is used to identify the application registration via the
|
||||
# identifier_uris field. Depending on the environment this value needs
|
||||
# to be from an approved domain The format of this value is derived
|
||||
# from the default value proposed by azure when creating an application
|
||||
# registration api://{guid}/...
|
||||
if self.multi_tenant_domain:
|
||||
return "api://%s/%s" % (
|
||||
self.multi_tenant_domain,
|
||||
self.application_name,
|
||||
)
|
||||
return [
|
||||
"api://%s/%s" % (self.multi_tenant_domain, name)
|
||||
for name in [
|
||||
self.application_name,
|
||||
self.application_name + DOTNET_APPLICATION_SUFFIX,
|
||||
]
|
||||
]
|
||||
else:
|
||||
return "api://%s.azurewebsites.net" % self.application_name
|
||||
return [
|
||||
"api://%s.azurewebsites.net" % name
|
||||
for name in [
|
||||
self.application_name,
|
||||
self.application_name + DOTNET_APPLICATION_SUFFIX,
|
||||
]
|
||||
]
|
||||
|
||||
def get_signin_audience(self) -> str:
|
||||
# https://docs.microsoft.com/en-us/azure/active-directory/develop/supported-accounts-validation
|
||||
@ -368,7 +388,7 @@ class Client:
|
||||
|
||||
params = {
|
||||
"displayName": self.application_name,
|
||||
"identifierUris": [self.get_identifier_url()],
|
||||
"identifierUris": self.get_identifier_urls(),
|
||||
"signInAudience": self.get_signin_audience(),
|
||||
"appRoles": app_roles,
|
||||
"api": {
|
||||
@ -391,7 +411,8 @@ class Client:
|
||||
"enableIdTokenIssuance": True,
|
||||
},
|
||||
"redirectUris": [
|
||||
f"{self.get_instance_url()}/.auth/login/aad/callback"
|
||||
f"{url}/.auth/login/aad/callback"
|
||||
for url in self.get_instance_urls()
|
||||
],
|
||||
},
|
||||
"requiredResourceAccess": [
|
||||
@ -452,12 +473,14 @@ class Client:
|
||||
try_sp_create()
|
||||
|
||||
else:
|
||||
existing_role_values = [app_role["value"] for app_role in app["appRoles"]]
|
||||
api_id = self.get_identifier_url()
|
||||
|
||||
if api_id not in app["identifierUris"]:
|
||||
identifier_uris = app["identifierUris"]
|
||||
identifier_uris.append(api_id)
|
||||
identifier_uris: List[str] = app["identifierUris"]
|
||||
api_ids = [
|
||||
id for id in self.get_identifier_urls() if id not in identifier_uris
|
||||
]
|
||||
|
||||
if len(api_ids) > 0:
|
||||
identifier_uris.extend(api_ids)
|
||||
query_microsoft_graph(
|
||||
method="PATCH",
|
||||
resource=f"applications/{app['id']}",
|
||||
@ -465,6 +488,8 @@ class Client:
|
||||
subscription=self.get_subscription_id(),
|
||||
)
|
||||
|
||||
existing_role_values = [app_role["value"] for app_role in app["appRoles"]]
|
||||
|
||||
has_missing_roles = any(
|
||||
[role["value"] not in existing_role_values for role in app_roles]
|
||||
)
|
||||
@ -569,10 +594,9 @@ class Client:
|
||||
"%Y-%m-%dT%H:%M:%SZ"
|
||||
)
|
||||
|
||||
app_func_audiences = [
|
||||
self.get_identifier_url(),
|
||||
self.get_instance_url(),
|
||||
]
|
||||
app_func_audiences = self.get_identifier_urls().copy()
|
||||
app_func_audiences.extend(self.get_instance_urls())
|
||||
|
||||
if self.multi_tenant_domain:
|
||||
# clear the value in the Issuer Url field:
|
||||
# https://docs.microsoft.com/en-us/sharepoint/dev/spfx/use-aadhttpclient-enterpriseapi-multitenant
|
||||
@ -650,7 +674,7 @@ class Client:
|
||||
if self.upgrade:
|
||||
logger.info("Upgrading: Skipping assignment of current user to app role")
|
||||
return
|
||||
logger.info("assinging user access to service principal")
|
||||
logger.info("assigning user access to service principal")
|
||||
app = get_application(
|
||||
display_name=self.application_name,
|
||||
subscription_id=self.get_subscription_id(),
|
||||
@ -1049,7 +1073,7 @@ class Client:
|
||||
"azure",
|
||||
"functionapp",
|
||||
"publish",
|
||||
self.application_name + "-net",
|
||||
self.application_name + DOTNET_APPLICATION_SUFFIX,
|
||||
"--no-build",
|
||||
],
|
||||
env=dict(os.environ, CLI_DEBUG="1"),
|
||||
@ -1107,7 +1131,7 @@ class Client:
|
||||
"appsettings",
|
||||
"set",
|
||||
"--name",
|
||||
self.application_name + "-net",
|
||||
self.application_name + DOTNET_APPLICATION_SUFFIX,
|
||||
"--resource-group",
|
||||
self.application_name,
|
||||
"--settings",
|
||||
|
Reference in New Issue
Block a user