Remove Additional config params - require on each request (#3000)

* Only Overrite Config Cache

* Lint

* Fixing isort.

* Removing expiry.

* Removing import.

* Removing config params.

* Remove bad import.

* Adjusting to type changes.

* Remove whitespace.

* Formatting.

* Formatting.

* null check.

* Formatting.
This commit is contained in:
Noah McGregor Harper
2023-04-11 11:35:09 -07:00
committed by GitHub
parent 77c42930a6
commit 169cef7a06
3 changed files with 31 additions and 47 deletions

View File

@ -12,8 +12,6 @@ def main(req: func.HttpRequest) -> func.HttpResponse:
o = Onefuzz() o = Onefuzz()
o.config( o.config(
endpoint=os.environ.get("ONEFUZZ_ENDPOINT"), endpoint=os.environ.get("ONEFUZZ_ENDPOINT"),
override_authority=os.environ.get("ONEFUZZ_AUTHORITY"),
client_id=os.environ.get("ONEFUZZ_CLIENT_ID"),
) )
info = o.info.get() info = o.info.get()
return func.HttpResponse(info.json()) return func.HttpResponse(info.json())

View File

@ -40,11 +40,7 @@ from .ssh import build_ssh_command, ssh_connect, temp_file
UUID_EXPANSION = TypeVar("UUID_EXPANSION", UUID, str) UUID_EXPANSION = TypeVar("UUID_EXPANSION", UUID, str)
DEFAULT = BackendConfig( DEFAULT = BackendConfig(endpoint="")
authority="",
client_id="",
tenant_domain="",
)
# This was generated randomly and should be preserved moving forwards # This was generated randomly and should be preserved moving forwards
ONEFUZZ_GUID_NAMESPACE = uuid.UUID("27f25e3f-6544-4b69-b309-9b096c5a9cbc") ONEFUZZ_GUID_NAMESPACE = uuid.UUID("27f25e3f-6544-4b69-b309-9b096c5a9cbc")
@ -1310,7 +1306,7 @@ class Pool(Endpoint):
raise Exception("Missing AgentConfig in response") raise Exception("Missing AgentConfig in response")
config = pool.config config = pool.config
if not pool.managed: if not pool.managed and self.onefuzz._backend.config.authority:
config.client_credentials = models.ClientCredentials( # nosec config.client_credentials = models.ClientCredentials( # nosec
client_id=uuid.UUID(int=0), client_id=uuid.UUID(int=0),
client_secret="<client_secret>", client_secret="<client_secret>",
@ -1894,9 +1890,6 @@ class Onefuzz:
def config( def config(
self, self,
endpoint: Optional[str] = None, endpoint: Optional[str] = None,
override_authority: Optional[str] = None,
client_id: Optional[str] = None,
override_tenant_domain: Optional[str] = None,
enable_feature: Optional[PreviewFeature] = None, enable_feature: Optional[PreviewFeature] = None,
reset: Optional[bool] = None, reset: Optional[bool] = None,
) -> BackendConfig: ) -> BackendConfig:
@ -1904,9 +1897,7 @@ class Onefuzz:
self.logger.debug("set config") self.logger.debug("set config")
if reset: if reset:
self._backend.config = BackendConfig( self._backend.config = BackendConfig(endpoint="")
authority="", client_id="", tenant_domain=""
)
if endpoint is not None: if endpoint is not None:
# The normal path for calling the API always uses the oauth2 workflow, # The normal path for calling the API always uses the oauth2 workflow,
@ -1922,17 +1913,12 @@ class Onefuzz:
"Missing HTTP Authentication" "Missing HTTP Authentication"
) )
self._backend.config.endpoint = endpoint self._backend.config.endpoint = endpoint
if client_id is not None:
self._backend.config.client_id = client_id
if override_authority is not None:
self._backend.config.authority = override_authority
if enable_feature: if enable_feature:
self._backend.enable_feature(enable_feature.name) self._backend.enable_feature(enable_feature.name)
if override_tenant_domain is not None:
self._backend.config.tenant_domain = override_tenant_domain
self._backend.app = None self._backend.app = None
self._backend.save_config() self._backend.save_config()
data = self._backend.config.copy(deep=True) data = self._backend.config.copy(deep=True)
if not data.endpoint: if not data.endpoint:

View File

@ -12,7 +12,6 @@ import sys
import tempfile import tempfile
import time import time
from dataclasses import asdict, is_dataclass from dataclasses import asdict, is_dataclass
from datetime import datetime, timedelta
from enum import Enum from enum import Enum
from typing import ( from typing import (
Any, Any,
@ -33,7 +32,7 @@ import msal
import requests import requests
from azure.storage.blob import ContainerClient from azure.storage.blob import ContainerClient
from onefuzztypes import responses from onefuzztypes import responses
from pydantic import BaseModel, Field from pydantic import BaseModel
from requests import Response from requests import Response
from tenacity import RetryCallState, retry from tenacity import RetryCallState, retry
from tenacity.retry import retry_if_exception_type from tenacity.retry import retry_if_exception_type
@ -93,20 +92,26 @@ def check_application_error(response: requests.Response) -> None:
class BackendConfig(BaseModel): class BackendConfig(BaseModel):
authority: str authority: Optional[str]
client_id: str client_id: Optional[str]
endpoint: Optional[str] endpoint: str
features: Set[str] = Field(default_factory=set) features: Optional[Set[str]]
tenant_domain: str tenant_domain: Optional[str]
expires_on: datetime = datetime.utcnow() + timedelta(hours=24)
def get_multi_tenant_domain(self) -> Optional[str]: def get_multi_tenant_domain(self) -> Optional[str]:
if "https://login.microsoftonline.com/common" in self.authority: if (
self.authority
and "https://login.microsoftonline.com/common" in self.authority
):
return self.tenant_domain return self.tenant_domain
else: else:
return None return None
class CacheConfig(BaseModel):
endpoint: Optional[str]
class Backend: class Backend:
def __init__( def __init__(
self, self,
@ -129,10 +134,14 @@ class Backend:
atexit.register(self.save_cache) atexit.register(self.save_cache)
def enable_feature(self, name: str) -> None: def enable_feature(self, name: str) -> None:
if not self.config.features:
self.config.features = Set[str]()
self.config.features.add(name) self.config.features.add(name)
def is_feature_enabled(self, name: str) -> bool: def is_feature_enabled(self, name: str) -> bool:
if self.config.features:
return name in self.config.features return name in self.config.features
return False
def load_config(self) -> None: def load_config(self) -> None:
if os.path.exists(self.config_path): if os.path.exists(self.config_path):
@ -143,7 +152,8 @@ class Backend:
def save_config(self) -> None: def save_config(self) -> None:
os.makedirs(os.path.dirname(self.config_path), exist_ok=True) os.makedirs(os.path.dirname(self.config_path), exist_ok=True)
with open(self.config_path, "w") as handle: with open(self.config_path, "w") as handle:
handle.write(self.config.json(indent=4, exclude_none=True)) endpoint_cache = {"endpoint": f"{self.config.endpoint}"}
handle.write(json.dumps(endpoint_cache, indent=4, sort_keys=True))
def init_cache(self) -> None: def init_cache(self) -> None:
# Ensure the token_path directory exists # Ensure the token_path directory exists
@ -331,15 +341,13 @@ class Backend:
endpoint_params = responses.Config.parse_obj(response.json()) endpoint_params = responses.Config.parse_obj(response.json())
# Will override values in storage w/ provided values for SP use # Will override values in storage w/ provided values for SP use
if self.config.client_id == "": if not self.config.client_id:
self.config.client_id = endpoint_params.client_id self.config.client_id = endpoint_params.client_id
if self.config.authority == "": if not self.config.authority:
self.config.authority = endpoint_params.authority self.config.authority = endpoint_params.authority
if self.config.tenant_domain == "": if not self.config.tenant_domain:
self.config.tenant_domain = endpoint_params.tenant_domain self.config.tenant_domain = endpoint_params.tenant_domain
self.save_config()
def request( def request(
self, self,
method: str, method: str,
@ -353,17 +361,9 @@ class Backend:
if not endpoint: if not endpoint:
raise Exception("endpoint not configured") raise Exception("endpoint not configured")
# If file expires, remove and force user to reset
if datetime.utcnow() > self.config.expires_on:
os.remove(self.config_path)
self.config = BackendConfig(
endpoint=endpoint, authority="", client_id="", tenant_domain=""
)
url = endpoint + "/api/" + path url = endpoint + "/api/" + path
if not self.config.client_id or (
if self.config.client_id == "" or ( not self.config.authority and not self.config.tenant_domain
self.config.authority == "" and self.config.tenant_domain == ""
): ):
self.config_params() self.config_params()
headers = self.headers() headers = self.headers()