diff --git a/src/cli/examples/azure-functions-example/info/__init__.py b/src/cli/examples/azure-functions-example/info/__init__.py index 4de729cda..bdb3418a5 100644 --- a/src/cli/examples/azure-functions-example/info/__init__.py +++ b/src/cli/examples/azure-functions-example/info/__init__.py @@ -12,8 +12,6 @@ def main(req: func.HttpRequest) -> func.HttpResponse: o = Onefuzz() o.config( endpoint=os.environ.get("ONEFUZZ_ENDPOINT"), - override_authority=os.environ.get("ONEFUZZ_AUTHORITY"), - client_id=os.environ.get("ONEFUZZ_CLIENT_ID"), ) info = o.info.get() return func.HttpResponse(info.json()) diff --git a/src/cli/onefuzz/api.py b/src/cli/onefuzz/api.py index cad2f15cb..8fe5d66e6 100644 --- a/src/cli/onefuzz/api.py +++ b/src/cli/onefuzz/api.py @@ -40,11 +40,7 @@ from .ssh import build_ssh_command, ssh_connect, temp_file UUID_EXPANSION = TypeVar("UUID_EXPANSION", UUID, str) -DEFAULT = BackendConfig( - authority="", - client_id="", - tenant_domain="", -) +DEFAULT = BackendConfig(endpoint="") # This was generated randomly and should be preserved moving forwards ONEFUZZ_GUID_NAMESPACE = uuid.UUID("27f25e3f-6544-4b69-b309-9b096c5a9cbc") @@ -1310,7 +1306,7 @@ class Pool(Endpoint): raise Exception("Missing AgentConfig in response") config = pool.config - if not pool.managed: + if not pool.managed and self.onefuzz._backend.config.authority: config.client_credentials = models.ClientCredentials( # nosec client_id=uuid.UUID(int=0), client_secret="", @@ -1894,9 +1890,6 @@ class Onefuzz: def config( self, endpoint: Optional[str] = None, - override_authority: Optional[str] = None, - client_id: Optional[str] = None, - override_tenant_domain: Optional[str] = None, enable_feature: Optional[PreviewFeature] = None, reset: Optional[bool] = None, ) -> BackendConfig: @@ -1904,9 +1897,7 @@ class Onefuzz: self.logger.debug("set config") if reset: - self._backend.config = BackendConfig( - authority="", client_id="", tenant_domain="" - ) + self._backend.config = BackendConfig(endpoint="") if endpoint is not None: # The normal path for calling the API always uses the oauth2 workflow, @@ -1922,17 +1913,12 @@ class Onefuzz: "Missing HTTP Authentication" ) self._backend.config.endpoint = endpoint - if client_id is not None: - self._backend.config.client_id = client_id - if override_authority is not None: - self._backend.config.authority = override_authority + if enable_feature: self._backend.enable_feature(enable_feature.name) - if override_tenant_domain is not None: - self._backend.config.tenant_domain = override_tenant_domain + self._backend.app = None self._backend.save_config() - data = self._backend.config.copy(deep=True) if not data.endpoint: diff --git a/src/cli/onefuzz/backend.py b/src/cli/onefuzz/backend.py index 4811d9188..3e91938c2 100644 --- a/src/cli/onefuzz/backend.py +++ b/src/cli/onefuzz/backend.py @@ -12,7 +12,6 @@ import sys import tempfile import time from dataclasses import asdict, is_dataclass -from datetime import datetime, timedelta from enum import Enum from typing import ( Any, @@ -33,7 +32,7 @@ import msal import requests from azure.storage.blob import ContainerClient from onefuzztypes import responses -from pydantic import BaseModel, Field +from pydantic import BaseModel from requests import Response from tenacity import RetryCallState, retry from tenacity.retry import retry_if_exception_type @@ -93,20 +92,26 @@ def check_application_error(response: requests.Response) -> None: class BackendConfig(BaseModel): - authority: str - client_id: str - endpoint: Optional[str] - features: Set[str] = Field(default_factory=set) - tenant_domain: str - expires_on: datetime = datetime.utcnow() + timedelta(hours=24) + authority: Optional[str] + client_id: Optional[str] + endpoint: str + features: Optional[Set[str]] + tenant_domain: Optional[str] def get_multi_tenant_domain(self) -> Optional[str]: - if "https://login.microsoftonline.com/common" in self.authority: + if ( + self.authority + and "https://login.microsoftonline.com/common" in self.authority + ): return self.tenant_domain else: return None +class CacheConfig(BaseModel): + endpoint: Optional[str] + + class Backend: def __init__( self, @@ -129,10 +134,14 @@ class Backend: atexit.register(self.save_cache) def enable_feature(self, name: str) -> None: + if not self.config.features: + self.config.features = Set[str]() self.config.features.add(name) def is_feature_enabled(self, name: str) -> bool: - return name in self.config.features + if self.config.features: + return name in self.config.features + return False def load_config(self) -> None: if os.path.exists(self.config_path): @@ -143,7 +152,8 @@ class Backend: def save_config(self) -> None: os.makedirs(os.path.dirname(self.config_path), exist_ok=True) with open(self.config_path, "w") as handle: - handle.write(self.config.json(indent=4, exclude_none=True)) + endpoint_cache = {"endpoint": f"{self.config.endpoint}"} + handle.write(json.dumps(endpoint_cache, indent=4, sort_keys=True)) def init_cache(self) -> None: # Ensure the token_path directory exists @@ -331,15 +341,13 @@ class Backend: endpoint_params = responses.Config.parse_obj(response.json()) # Will override values in storage w/ provided values for SP use - if self.config.client_id == "": + if not self.config.client_id: self.config.client_id = endpoint_params.client_id - if self.config.authority == "": + if not self.config.authority: self.config.authority = endpoint_params.authority - if self.config.tenant_domain == "": + if not self.config.tenant_domain: self.config.tenant_domain = endpoint_params.tenant_domain - self.save_config() - def request( self, method: str, @@ -353,17 +361,9 @@ class Backend: if not endpoint: raise Exception("endpoint not configured") - # If file expires, remove and force user to reset - if datetime.utcnow() > self.config.expires_on: - os.remove(self.config_path) - self.config = BackendConfig( - endpoint=endpoint, authority="", client_id="", tenant_domain="" - ) - url = endpoint + "/api/" + path - - if self.config.client_id == "" or ( - self.config.authority == "" and self.config.tenant_domain == "" + if not self.config.client_id or ( + not self.config.authority and not self.config.tenant_domain ): self.config_params() headers = self.headers()