diff --git a/src/ApiService/ApiService/AgentEvents.cs b/src/ApiService/ApiService/AgentEvents.cs index 2f1dff6a8..3c9e1d6f6 100644 --- a/src/ApiService/ApiService/AgentEvents.cs +++ b/src/ApiService/ApiService/AgentEvents.cs @@ -17,8 +17,8 @@ public class AgentEvents { private static readonly EntityConverter _entityConverter = new(); - // [Function("AgentEvents")] - public async Async.Task Run([HttpTrigger("post")] HttpRequestData req) { + [Function("AgentEvents")] + public async Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "POST")] HttpRequestData req) { var request = await RequestHandling.ParseRequest(req); if (!request.IsOk || request.OkV == null) { return await _context.RequestHandling.NotOk(req, request.ErrorV, context: "node event"); diff --git a/src/ApiService/ApiService/Download.cs b/src/ApiService/ApiService/Download.cs index 5f5516e9b..2eb4bf71f 100644 --- a/src/ApiService/ApiService/Download.cs +++ b/src/ApiService/ApiService/Download.cs @@ -14,8 +14,8 @@ public class Download { _context = context; } - // [Function("Download")] - public Async.Task Run([HttpTrigger("GET")] HttpRequestData req) + [Function("Download")] + public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET")] HttpRequestData req) => _auth.CallIfUser(req, Get); private async Async.Task Get(HttpRequestData req) { diff --git a/src/ApiService/ApiService/GroupMembershipChecker.cs b/src/ApiService/ApiService/GroupMembershipChecker.cs new file mode 100644 index 000000000..f1b7737c0 --- /dev/null +++ b/src/ApiService/ApiService/GroupMembershipChecker.cs @@ -0,0 +1,44 @@ +using System.Net.Http; +using System.Threading.Tasks; + +namespace Microsoft.OneFuzz.Service; + +abstract class GroupMembershipChecker { + protected abstract Async.Task> GetGroups(Guid memberId); + + public async Async.Task IsMember(IEnumerable groupIds, Guid memberId) { + if (groupIds.Contains(memberId)) { + return true; + } + + var memberGroups = await GetGroups(memberId); + if (groupIds.Any(memberGroups.Contains)) { + return true; + } + + return false; + } +} + +class AzureADGroupMembership : GroupMembershipChecker { + private readonly ICreds _creds; + public AzureADGroupMembership(ICreds creds) => _creds = creds; + protected override async Task> GetGroups(Guid memberId) => + await _creds.QueryMicrosoftGraph>(HttpMethod.Get, $"users/{memberId}/transitiveMemberOf"); +} + +class StaticGroupMembership : GroupMembershipChecker { + private readonly Dictionary> _memberships; + public StaticGroupMembership(IDictionary memberships) { + _memberships = memberships.ToDictionary(kvp => kvp.Key, kvp => kvp.Value.ToList()); + } + + protected override Task> GetGroups(Guid memberId) { + var result = Enumerable.Empty(); + if (_memberships.TryGetValue(memberId, out var found)) { + result = found; + } + + return Async.Task.FromResult(result); + } +} diff --git a/src/ApiService/ApiService/Info.cs b/src/ApiService/ApiService/Info.cs index 8621c042a..3a4d6b413 100644 --- a/src/ApiService/ApiService/Info.cs +++ b/src/ApiService/ApiService/Info.cs @@ -31,12 +31,13 @@ public class Info { var asm = Assembly.GetExecutingAssembly(); var gitVersion = ReadResource(asm, "ApiService.onefuzzlib.git.version"); var buildId = ReadResource(asm, "ApiService.onefuzzlib.build.id"); + var versionString = asm.GetCustomAttribute()?.InformationalVersion; return new InfoResponse( ResourceGroup: resourceGroup, Subscription: subscription, Region: region, - Versions: new Dictionary { { "onefuzz", new(gitVersion, buildId, config.OneFuzzVersion) } }, + Versions: new Dictionary { { "onefuzz", new(gitVersion, buildId, versionString ?? "") } }, InstanceId: await _context.Containers.GetInstanceId(), InsightsAppid: config.ApplicationInsightsAppId, InsightsInstrumentationKey: config.ApplicationInsightsInstrumentationKey); @@ -50,13 +51,13 @@ public class Info { } using var sr = new StreamReader(r); - return sr.ReadToEnd(); + return sr.ReadToEnd().Trim(); } private async Async.Task GetResponse(HttpRequestData req) => await RequestHandling.Ok(req, await _response.Value); - // [Function("Info")] - public Async.Task Run([HttpTrigger("GET")] HttpRequestData req) + [Function("Info")] + public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET")] HttpRequestData req) => _auth.CallIfUser(req, GetResponse); } diff --git a/src/ApiService/ApiService/Node.cs b/src/ApiService/ApiService/Node.cs index 65ee6f222..ca445880c 100644 --- a/src/ApiService/ApiService/Node.cs +++ b/src/ApiService/ApiService/Node.cs @@ -18,8 +18,8 @@ public class NodeFunction { private static readonly EntityConverter _entityConverter = new(); - // [Function("Node") - public Async.Task Run([HttpTrigger("GET", "PATCH", "POST", "DELETE")] HttpRequestData req) { + [Function("Node")] + public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET", "PATCH", "POST", "DELETE")] HttpRequestData req) { return _auth.CallIfUser(req, r => r.Method switch { "GET" => Get(r), "PATCH" => Patch(r), diff --git a/src/ApiService/ApiService/OneFuzzTypes/Model.cs b/src/ApiService/ApiService/OneFuzzTypes/Model.cs index b78d1539b..e0c28cfbb 100644 --- a/src/ApiService/ApiService/OneFuzzTypes/Model.cs +++ b/src/ApiService/ApiService/OneFuzzTypes/Model.cs @@ -299,8 +299,8 @@ public record NetworkSecurityGroupConfig( } public record ApiAccessRule( - string[] Methods, - Guid[] AllowedGroups + IReadOnlyList Methods, + IReadOnlyList AllowedGroups ); //# initial set of admins can only be set during deployment. diff --git a/src/ApiService/ApiService/onefuzzlib/Creds.cs b/src/ApiService/ApiService/onefuzzlib/Creds.cs index c1e21de81..bb99642ff 100644 --- a/src/ApiService/ApiService/onefuzzlib/Creds.cs +++ b/src/ApiService/ApiService/onefuzzlib/Creds.cs @@ -1,4 +1,6 @@ -using System.Text.Json; +using System.Net.Http; +using System.Net.Http.Json; +using System.Threading.Tasks; using Azure.Core; using Azure.Identity; using Azure.ResourceManager; @@ -24,18 +26,21 @@ public interface ICreds { public Async.Task GetBaseRegion(); public Uri GetInstanceUrl(); - Guid GetScalesetPrincipalId(); + public Async.Task GetScalesetPrincipalId(); + public Async.Task QueryMicrosoftGraph(HttpMethod method, string resource); } public class Creds : ICreds { private readonly ArmClient _armClient; private readonly DefaultAzureCredential _azureCredential; private readonly IServiceConfig _config; + private readonly IHttpClientFactory _httpClientFactory; public ArmClient ArmClient => _armClient; - public Creds(IServiceConfig config) { + public Creds(IServiceConfig config, IHttpClientFactory httpClientFactory) { _config = config; + _httpClientFactory = httpClientFactory; _azureCredential = new DefaultAzureCredential(); _armClient = new ArmClient(this.GetIdentity(), this.GetSubscription()); } @@ -88,11 +93,14 @@ public class Creds : ICreds { return new Uri($"https://{GetInstanceName()}.azurewebsites.net"); } - public Guid GetScalesetPrincipalId() { - var uid = ArmClient.GetGenericResource( - new ResourceIdentifier(GetScalesetIdentityResourcePath()) - ); - var principalId = JsonSerializer.Deserialize(uid.Data.Properties.ToString())?.RootElement.GetProperty("principalId").GetString()!; + public record ScaleSetIdentity(string principalId); + + public async Async.Task GetScalesetPrincipalId() { + var path = GetScalesetIdentityResourcePath(); + var uid = ArmClient.GetGenericResource(new ResourceIdentifier(path)); + + var resource = await uid.GetAsync(); + var principalId = resource.Value.Data.Properties.ToObjectFromJson().principalId; return new Guid(principalId); } @@ -102,4 +110,44 @@ public class Creds : ICreds { return $"{resourceGroupPath}/Microsoft.ManagedIdentity/userAssignedIdentities/{scalesetIdName}"; } + + + // https://docs.microsoft.com/en-us/graph/api/overview?view=graph-rest-1.0 + private static readonly Uri _graphResource = new("https://graph.microsoft.com"); + private static readonly Uri _graphResourceEndpoint = new("https://graph.microsoft.com/v1.0"); + + public async Task QueryMicrosoftGraph(HttpMethod method, string resource) { + var cred = GetIdentity(); + + var scopes = new string[] { $"{_graphResource}/.default" }; + var accessToken = await cred.GetTokenAsync(new TokenRequestContext(scopes)); + + var uri = new Uri($"{_graphResourceEndpoint}/{resource}"); + using var httpClient = _httpClientFactory.CreateClient(); + using var response = await httpClient.SendAsync(new HttpRequestMessage { + Headers = { + {"Authorization", $"Bearer {accessToken.Token}"}, + {"Content-Type", "application/json"}, + }, + Method = method, + RequestUri = uri, + }); + + if (response.IsSuccessStatusCode) { + var result = await response.Content.ReadFromJsonAsync(); + if (result is null) { + throw new GraphQueryException($"invalid data expected a json object: HTTP {response.StatusCode}"); + } + + return result; + } else { + var errorText = await response.Content.ReadAsStringAsync(); + throw new GraphQueryException($"request did not succeed: HTTP {response.StatusCode} - {errorText}"); + } + } +} + +class GraphQueryException : Exception { + public GraphQueryException(string? message) : base(message) { + } } diff --git a/src/ApiService/ApiService/onefuzzlib/EndpointAuthorization.cs b/src/ApiService/ApiService/onefuzzlib/EndpointAuthorization.cs index 67562a2dc..7fec5807d 100644 --- a/src/ApiService/ApiService/onefuzzlib/EndpointAuthorization.cs +++ b/src/ApiService/ApiService/onefuzzlib/EndpointAuthorization.cs @@ -1,4 +1,5 @@ using System.Net; +using System.Net.Http; using Microsoft.Azure.Functions.Worker.Http; namespace Microsoft.OneFuzz.Service; @@ -45,13 +46,12 @@ public class EndpointAuthorization : IEndpointAuthorization { return await Reject(req, token); } - var access = CheckAccess(req); + var access = await CheckAccess(req); if (!access.IsOk) { return await _context.RequestHandling.NotOk(req, access.ErrorV, "access control", HttpStatusCode.Unauthorized); } } - if (await IsAgent(token) && !allowAgent) { return await Reject(req, token); } @@ -132,8 +132,54 @@ public class EndpointAuthorization : IEndpointAuthorization { } } - public OneFuzzResultVoid CheckAccess(HttpRequestData req) { - throw new NotImplementedException(); + public async Async.Task CheckAccess(HttpRequestData req) { + var instanceConfig = await _context.ConfigOperations.Fetch(); + + var rules = GetRules(instanceConfig); + if (rules is null) { + return OneFuzzResultVoid.Ok; + } + + var path = req.Url.AbsolutePath; + var rule = rules.GetMatchingRules(new HttpMethod(req.Method), path); + if (rule is null) { + return OneFuzzResultVoid.Ok; + } + + var memberId = Guid.Parse(req.Headers.GetValues("x-ms-client-principal-id").Single()); + try { + var membershipChecker = CreateGroupMembershipChecker(instanceConfig); + var allowed = await membershipChecker.IsMember(rule.AllowedGroupsIds, memberId); + if (!allowed) { + _log.Error($"unauthorized access: {memberId} is not authorized to access {path}"); + return new Error( + Code: ErrorCode.UNAUTHORIZED, + Errors: new string[] { "not approved to use this endpoint" }); + } else { + return OneFuzzResultVoid.Ok; + } + } catch (Exception ex) { + return new Error( + Code: ErrorCode.UNAUTHORIZED, + Errors: new string[] { "unable to interact with graph", ex.Message }); + } + } + + private GroupMembershipChecker CreateGroupMembershipChecker(InstanceConfig config) { + if (config.GroupMembership is not null) { + return new StaticGroupMembership(config.GroupMembership); + } + + return new AzureADGroupMembership(_context.Creds); + } + + private static RequestAccess? GetRules(InstanceConfig config) { + var accessRules = config?.ApiAccessRules; + if (accessRules is not null) { + return RequestAccess.Build(accessRules); + } + + return null; } public async Async.Task IsAgent(UserInfo tokenData) { @@ -143,7 +189,7 @@ public class EndpointAuthorization : IEndpointAuthorization { return true; } - var principalId = _context.Creds.GetScalesetPrincipalId(); + var principalId = await _context.Creds.GetScalesetPrincipalId(); return principalId == tokenData.ObjectId; } diff --git a/src/ApiService/ApiService/onefuzzlib/InstanceConfig.cs b/src/ApiService/ApiService/onefuzzlib/InstanceConfig.cs index bb410c248..5d4e73f6c 100644 --- a/src/ApiService/ApiService/onefuzzlib/InstanceConfig.cs +++ b/src/ApiService/ApiService/onefuzzlib/InstanceConfig.cs @@ -18,6 +18,7 @@ public class ConfigOperations : Orm, IConfigOperations { } public async Task Fetch() { + // TODO: cache this for some period var key = _context.ServiceConfiguration.OneFuzzInstanceName ?? throw new Exception("Environment variable ONEFUZZ_INSTANCE_NAME is not set"); var config = await GetEntityAsync(key, key); return config; diff --git a/src/ApiService/ApiService/onefuzzlib/RequestAccess.cs b/src/ApiService/ApiService/onefuzzlib/RequestAccess.cs new file mode 100644 index 000000000..57e70e0c6 --- /dev/null +++ b/src/ApiService/ApiService/onefuzzlib/RequestAccess.cs @@ -0,0 +1,93 @@ + +using System.Net.Http; + +namespace Microsoft.OneFuzz.Service; + +public class RequestAccess { + private readonly Node _root = new(); + + public record Rules(IReadOnlyList AllowedGroupsIds); + record Node( + // HTTP Method -> Rules + Dictionary Rules, + // Path Segment -> Node + Dictionary Children) { + public Node() : this(new(), new()) { } + } + + private void AddUri(IEnumerable methods, string path, Rules rules) { + var segments = path.Split('/', StringSplitOptions.RemoveEmptyEntries); + if (!segments.Any()) { + return; + } + + var currentNode = _root; + var currentSegmentIndex = 0; + + while (currentSegmentIndex < segments.Length) { + var currentSegment = segments[currentSegmentIndex]; + if (currentNode.Children.ContainsKey(currentSegment)) { + currentNode = currentNode.Children[currentSegment]; + currentSegmentIndex++; + } else { + break; + } + } + + // we found a node matching this exact path + // This means that there is an existing rule causing a conflict + if (currentSegmentIndex == segments.Length) { + var conflictingMethod = methods.FirstOrDefault(m => currentNode.Rules.ContainsKey(m)); + if (conflictingMethod is not null) { + throw new RuleConflictException($"Conflicting rules on {conflictingMethod} {path}"); + } + } + + while (currentSegmentIndex < segments.Length) { + var currentSegment = segments[currentSegmentIndex]; + currentNode = currentNode.Children[currentSegment] = new Node(); + currentSegmentIndex++; + } + + foreach (var method in methods) { + currentNode.Rules[method] = rules; + } + } + + public static RequestAccess Build(IDictionary rules) { + var result = new RequestAccess(); + foreach (var (endpoint, rule) in rules) { + result.AddUri(rule.Methods.Select(x => new HttpMethod(x)), endpoint, new Rules(rule.AllowedGroups)); + } + + return result; + } + + public Rules? GetMatchingRules(HttpMethod method, string path) { + var segments = path.Split("/", StringSplitOptions.RemoveEmptyEntries); + + var currentNode = _root; + currentNode.Rules.TryGetValue(method, out var currentRule); + + foreach (var currentSegment in segments) { + if (currentNode.Children.TryGetValue(currentSegment, out var node)) { + currentNode = node; + } else if (currentNode.Children.TryGetValue("*", out var starNode)) { + currentNode = starNode; + } else { + break; + } + + if (currentNode.Rules.TryGetValue(method, out var rule)) { + currentRule = rule; + } + } + + return currentRule; + } +} + +public sealed class RuleConflictException : Exception { + public RuleConflictException(string? message) : base(message) { + } +} diff --git a/src/ApiService/IntegrationTests/Fakes/TestCreds.cs b/src/ApiService/IntegrationTests/Fakes/TestCreds.cs index 85eee4c6a..6a43a0fe6 100644 --- a/src/ApiService/IntegrationTests/Fakes/TestCreds.cs +++ b/src/ApiService/IntegrationTests/Fakes/TestCreds.cs @@ -1,4 +1,5 @@ using System; +using System.Net.Http; using System.Threading.Tasks; using Azure.Core; using Azure.Identity; @@ -50,7 +51,11 @@ class TestCreds : ICreds { throw new NotImplementedException(); } - public Guid GetScalesetPrincipalId() { + public Task GetScalesetPrincipalId() { + throw new NotImplementedException(); + } + + public Task QueryMicrosoftGraph(HttpMethod method, string resource) { throw new NotImplementedException(); } } diff --git a/src/ApiService/Tests/OrmModelsTest.cs b/src/ApiService/Tests/OrmModelsTest.cs index 87c0c920a..f185d6f1b 100644 --- a/src/ApiService/Tests/OrmModelsTest.cs +++ b/src/ApiService/Tests/OrmModelsTest.cs @@ -334,6 +334,9 @@ namespace Tests { public static Arbitrary PoolName { get; } = OrmGenerators.PoolNameGen.ToArbitrary(); + public static Arbitrary> ReadOnlyList() + => Arb.Default.List().Convert(x => (IReadOnlyList)x, x => (List)x); + public static Arbitrary Version() { return Arb.From(OrmGenerators.Version()); } diff --git a/src/ApiService/Tests/RequestAccessTests.cs b/src/ApiService/Tests/RequestAccessTests.cs new file mode 100644 index 000000000..be0335980 --- /dev/null +++ b/src/ApiService/Tests/RequestAccessTests.cs @@ -0,0 +1,152 @@ +using System; +using System.Collections.Generic; +using System.Net.Http; +using Microsoft.OneFuzz.Service; +using Xunit; + +namespace Tests; + +public class RequestAccessTests { + + [Fact] + public void TestEmpty() { + var requestAccess1 = RequestAccess.Build(new Dictionary()); + var rules1 = requestAccess1.GetMatchingRules(HttpMethod.Get, "a/b/c"); + Assert.Null(rules1); + + var requestAccess2 = RequestAccess.Build( + new Dictionary{ + { "a/b/c", new ApiAccessRule( + Methods: new[]{"get"}, + AllowedGroups: new[]{Guid.NewGuid()})}}); + + var rules2 = requestAccess2.GetMatchingRules(HttpMethod.Get, ""); + Assert.Null(rules2); + } + + [Fact] + public void TestExactMatch() { + var guid1 = Guid.NewGuid(); + var requestAccess = RequestAccess.Build( + new Dictionary{ + { "a/b/c", new ApiAccessRule( + Methods: new[]{"get"}, + AllowedGroups: new []{guid1})}}); + + var rules1 = requestAccess.GetMatchingRules(HttpMethod.Get, "a/b/c"); + Assert.NotNull(rules1); + var foundGuid = Assert.Single(rules1!.AllowedGroupsIds); + Assert.Equal(guid1, foundGuid); + + var rules2 = requestAccess.GetMatchingRules(HttpMethod.Get, "b/b/e"); + Assert.Null(rules2); + } + + [Fact] + public void TestWildcard() { + var guid1 = Guid.NewGuid(); + var requestAccess = RequestAccess.Build( + new Dictionary{ + { "b/*/c", new ApiAccessRule( + Methods: new[]{"get"}, + AllowedGroups: new []{guid1})}}); + + var rules = requestAccess.GetMatchingRules(HttpMethod.Get, "b/b/c"); + Assert.NotNull(rules); + var foundGuid = Assert.Single(rules!.AllowedGroupsIds); + Assert.Equal(guid1, foundGuid); + } + + [Fact] + public void TestAddingRuleOnSamePath() { + Assert.Throws(() => { + var guid1 = Guid.NewGuid(); + RequestAccess.Build( + new Dictionary{ + { "a/b/c", new ApiAccessRule( + Methods: new[]{"get"}, + AllowedGroups: new []{guid1})}, + { "a/b/c/", new ApiAccessRule( + Methods: new[]{"get"}, + AllowedGroups: Array.Empty())}}); + }); + } + + [Fact] + public void TestPriority() { + // The most specific rules takes priority over the ones containing a wildcard + var guid1 = Guid.NewGuid(); + var guid2 = Guid.NewGuid(); + var requestAccess = RequestAccess.Build( + new Dictionary{ + { "a/*/c", new ApiAccessRule( + Methods: new[]{"get"}, + AllowedGroups: new []{guid1})}, + { "a/b/c", new ApiAccessRule( + Methods: new[]{"get"}, + AllowedGroups: new[]{guid2})}}); + + var rules = requestAccess.GetMatchingRules(HttpMethod.Get, "a/b/c"); + Assert.NotNull(rules); + Assert.Equal(guid2, Assert.Single(rules!.AllowedGroupsIds)); // should match the most specific rule + } + + [Fact] + public void TestInheritRule() { + // if a path has no specific rule. it inherits from the parents + // /a/b/c inherit from a/b + var guid1 = Guid.NewGuid(); + var guid2 = Guid.NewGuid(); + var guid3 = Guid.NewGuid(); + var requestAccess = RequestAccess.Build( + new Dictionary{ + { "a/b/c", new ApiAccessRule( + Methods: new[]{"get"}, + AllowedGroups: new []{guid1})}, + { "f/*/c", new ApiAccessRule( + Methods: new[]{"get"}, + AllowedGroups: new[]{guid2})}, + { "a/b", new ApiAccessRule( + Methods: new[]{"post"}, + AllowedGroups: new []{guid3})}}); + + // should inherit a/b/c + var rules1 = requestAccess.GetMatchingRules(HttpMethod.Get, "a/b/c/d"); + Assert.NotNull(rules1); + Assert.Equal(guid1, Assert.Single(rules1!.AllowedGroupsIds)); + + // should inherit f/*/c + var rules2 = requestAccess.GetMatchingRules(HttpMethod.Get, "f/b/c/d"); + Assert.NotNull(rules2); + Assert.Equal(guid2, Assert.Single(rules2!.AllowedGroupsIds)); + + // post should inherit a/b + var rules3 = requestAccess.GetMatchingRules(HttpMethod.Post, "a/b/c/d"); + Assert.NotNull(rules3); + Assert.Equal(guid3, Assert.Single(rules3!.AllowedGroupsIds)); + } + + [Fact] + public void TestOverrideRule() { + var guid1 = Guid.NewGuid(); + var guid2 = Guid.NewGuid(); + var requestAccess = RequestAccess.Build( + new Dictionary{ + { "a/b/c", new ApiAccessRule( + Methods: new[]{"get"}, + AllowedGroups: new []{guid1})}, + { "a/b/c/d", new ApiAccessRule( + Methods: new[]{"get"}, + AllowedGroups: new []{guid2})}}); + + // should inherit a/b/c + var rules1 = requestAccess.GetMatchingRules(HttpMethod.Get, "a/b/c"); + Assert.NotNull(rules1); + Assert.Equal(guid1, Assert.Single(rules1!.AllowedGroupsIds)); + + // should inherit a/b/c/d + var rules2 = requestAccess.GetMatchingRules(HttpMethod.Get, "a/b/c/d"); + Assert.NotNull(rules2); + Assert.Equal(guid2, Assert.Single(rules2!.AllowedGroupsIds)); + } +} diff --git a/src/deployment/deploy.py b/src/deployment/deploy.py index c7f05f4d3..ae0311290 100644 --- a/src/deployment/deploy.py +++ b/src/deployment/deploy.py @@ -94,6 +94,8 @@ FUNC_TOOLS_ERROR = ( "https://github.com/Azure/azure-functions-core-tools#installing" ) +DOTNET_APPLICATION_SUFFIX = "-net" + logger = logging.getLogger("deploy") @@ -289,31 +291,49 @@ class Client: "cli_password", object_id, self.get_subscription_id() ) - def get_instance_url(self) -> str: + def get_instance_urls(self) -> List[str]: # The url to access the instance # This also represents the legacy identifier_uris of the application # registration if self.multi_tenant_domain: - return "https://%s/%s" % ( - self.multi_tenant_domain, - self.application_name, - ) + return [ + "https://%s/%s" % (self.multi_tenant_domain, name) + for name in [ + self.application_name, + self.application_name + DOTNET_APPLICATION_SUFFIX, + ] + ] else: - return "https://%s.azurewebsites.net" % self.application_name + return [ + "https://%s.azurewebsites.net" % name + for name in [ + self.application_name, + self.application_name + DOTNET_APPLICATION_SUFFIX, + ] + ] - def get_identifier_url(self) -> str: + def get_identifier_urls(self) -> List[str]: # This is used to identify the application registration via the # identifier_uris field. Depending on the environment this value needs # to be from an approved domain The format of this value is derived # from the default value proposed by azure when creating an application # registration api://{guid}/... if self.multi_tenant_domain: - return "api://%s/%s" % ( - self.multi_tenant_domain, - self.application_name, - ) + return [ + "api://%s/%s" % (self.multi_tenant_domain, name) + for name in [ + self.application_name, + self.application_name + DOTNET_APPLICATION_SUFFIX, + ] + ] else: - return "api://%s.azurewebsites.net" % self.application_name + return [ + "api://%s.azurewebsites.net" % name + for name in [ + self.application_name, + self.application_name + DOTNET_APPLICATION_SUFFIX, + ] + ] def get_signin_audience(self) -> str: # https://docs.microsoft.com/en-us/azure/active-directory/develop/supported-accounts-validation @@ -368,7 +388,7 @@ class Client: params = { "displayName": self.application_name, - "identifierUris": [self.get_identifier_url()], + "identifierUris": self.get_identifier_urls(), "signInAudience": self.get_signin_audience(), "appRoles": app_roles, "api": { @@ -391,7 +411,8 @@ class Client: "enableIdTokenIssuance": True, }, "redirectUris": [ - f"{self.get_instance_url()}/.auth/login/aad/callback" + f"{url}/.auth/login/aad/callback" + for url in self.get_instance_urls() ], }, "requiredResourceAccess": [ @@ -452,12 +473,14 @@ class Client: try_sp_create() else: - existing_role_values = [app_role["value"] for app_role in app["appRoles"]] - api_id = self.get_identifier_url() - if api_id not in app["identifierUris"]: - identifier_uris = app["identifierUris"] - identifier_uris.append(api_id) + identifier_uris: List[str] = app["identifierUris"] + api_ids = [ + id for id in self.get_identifier_urls() if id not in identifier_uris + ] + + if len(api_ids) > 0: + identifier_uris.extend(api_ids) query_microsoft_graph( method="PATCH", resource=f"applications/{app['id']}", @@ -465,6 +488,8 @@ class Client: subscription=self.get_subscription_id(), ) + existing_role_values = [app_role["value"] for app_role in app["appRoles"]] + has_missing_roles = any( [role["value"] not in existing_role_values for role in app_roles] ) @@ -569,10 +594,9 @@ class Client: "%Y-%m-%dT%H:%M:%SZ" ) - app_func_audiences = [ - self.get_identifier_url(), - self.get_instance_url(), - ] + app_func_audiences = self.get_identifier_urls().copy() + app_func_audiences.extend(self.get_instance_urls()) + if self.multi_tenant_domain: # clear the value in the Issuer Url field: # https://docs.microsoft.com/en-us/sharepoint/dev/spfx/use-aadhttpclient-enterpriseapi-multitenant @@ -650,7 +674,7 @@ class Client: if self.upgrade: logger.info("Upgrading: Skipping assignment of current user to app role") return - logger.info("assinging user access to service principal") + logger.info("assigning user access to service principal") app = get_application( display_name=self.application_name, subscription_id=self.get_subscription_id(), @@ -1049,7 +1073,7 @@ class Client: "azure", "functionapp", "publish", - self.application_name + "-net", + self.application_name + DOTNET_APPLICATION_SUFFIX, "--no-build", ], env=dict(os.environ, CLI_DEBUG="1"), @@ -1107,7 +1131,7 @@ class Client: "appsettings", "set", "--name", - self.application_name + "-net", + self.application_name + DOTNET_APPLICATION_SUFFIX, "--resource-group", self.application_name, "--settings",