Remove Additional config params - require on each request (#3000)

* Only Overrite Config Cache

* Lint

* Fixing isort.

* Removing expiry.

* Removing import.

* Removing config params.

* Remove bad import.

* Adjusting to type changes.

* Remove whitespace.

* Formatting.

* Formatting.

* null check.

* Formatting.
This commit is contained in:
Noah McGregor Harper
2023-04-11 11:35:09 -07:00
committed by GitHub
parent 77c42930a6
commit 169cef7a06
3 changed files with 31 additions and 47 deletions

View File

@ -12,8 +12,6 @@ def main(req: func.HttpRequest) -> func.HttpResponse:
o = Onefuzz()
o.config(
endpoint=os.environ.get("ONEFUZZ_ENDPOINT"),
override_authority=os.environ.get("ONEFUZZ_AUTHORITY"),
client_id=os.environ.get("ONEFUZZ_CLIENT_ID"),
)
info = o.info.get()
return func.HttpResponse(info.json())

View File

@ -40,11 +40,7 @@ from .ssh import build_ssh_command, ssh_connect, temp_file
UUID_EXPANSION = TypeVar("UUID_EXPANSION", UUID, str)
DEFAULT = BackendConfig(
authority="",
client_id="",
tenant_domain="",
)
DEFAULT = BackendConfig(endpoint="")
# This was generated randomly and should be preserved moving forwards
ONEFUZZ_GUID_NAMESPACE = uuid.UUID("27f25e3f-6544-4b69-b309-9b096c5a9cbc")
@ -1310,7 +1306,7 @@ class Pool(Endpoint):
raise Exception("Missing AgentConfig in response")
config = pool.config
if not pool.managed:
if not pool.managed and self.onefuzz._backend.config.authority:
config.client_credentials = models.ClientCredentials( # nosec
client_id=uuid.UUID(int=0),
client_secret="<client_secret>",
@ -1894,9 +1890,6 @@ class Onefuzz:
def config(
self,
endpoint: Optional[str] = None,
override_authority: Optional[str] = None,
client_id: Optional[str] = None,
override_tenant_domain: Optional[str] = None,
enable_feature: Optional[PreviewFeature] = None,
reset: Optional[bool] = None,
) -> BackendConfig:
@ -1904,9 +1897,7 @@ class Onefuzz:
self.logger.debug("set config")
if reset:
self._backend.config = BackendConfig(
authority="", client_id="", tenant_domain=""
)
self._backend.config = BackendConfig(endpoint="")
if endpoint is not None:
# The normal path for calling the API always uses the oauth2 workflow,
@ -1922,17 +1913,12 @@ class Onefuzz:
"Missing HTTP Authentication"
)
self._backend.config.endpoint = endpoint
if client_id is not None:
self._backend.config.client_id = client_id
if override_authority is not None:
self._backend.config.authority = override_authority
if enable_feature:
self._backend.enable_feature(enable_feature.name)
if override_tenant_domain is not None:
self._backend.config.tenant_domain = override_tenant_domain
self._backend.app = None
self._backend.save_config()
data = self._backend.config.copy(deep=True)
if not data.endpoint:

View File

@ -12,7 +12,6 @@ import sys
import tempfile
import time
from dataclasses import asdict, is_dataclass
from datetime import datetime, timedelta
from enum import Enum
from typing import (
Any,
@ -33,7 +32,7 @@ import msal
import requests
from azure.storage.blob import ContainerClient
from onefuzztypes import responses
from pydantic import BaseModel, Field
from pydantic import BaseModel
from requests import Response
from tenacity import RetryCallState, retry
from tenacity.retry import retry_if_exception_type
@ -93,20 +92,26 @@ def check_application_error(response: requests.Response) -> None:
class BackendConfig(BaseModel):
authority: str
client_id: str
endpoint: Optional[str]
features: Set[str] = Field(default_factory=set)
tenant_domain: str
expires_on: datetime = datetime.utcnow() + timedelta(hours=24)
authority: Optional[str]
client_id: Optional[str]
endpoint: str
features: Optional[Set[str]]
tenant_domain: Optional[str]
def get_multi_tenant_domain(self) -> Optional[str]:
if "https://login.microsoftonline.com/common" in self.authority:
if (
self.authority
and "https://login.microsoftonline.com/common" in self.authority
):
return self.tenant_domain
else:
return None
class CacheConfig(BaseModel):
endpoint: Optional[str]
class Backend:
def __init__(
self,
@ -129,10 +134,14 @@ class Backend:
atexit.register(self.save_cache)
def enable_feature(self, name: str) -> None:
if not self.config.features:
self.config.features = Set[str]()
self.config.features.add(name)
def is_feature_enabled(self, name: str) -> bool:
if self.config.features:
return name in self.config.features
return False
def load_config(self) -> None:
if os.path.exists(self.config_path):
@ -143,7 +152,8 @@ class Backend:
def save_config(self) -> None:
os.makedirs(os.path.dirname(self.config_path), exist_ok=True)
with open(self.config_path, "w") as handle:
handle.write(self.config.json(indent=4, exclude_none=True))
endpoint_cache = {"endpoint": f"{self.config.endpoint}"}
handle.write(json.dumps(endpoint_cache, indent=4, sort_keys=True))
def init_cache(self) -> None:
# Ensure the token_path directory exists
@ -331,15 +341,13 @@ class Backend:
endpoint_params = responses.Config.parse_obj(response.json())
# Will override values in storage w/ provided values for SP use
if self.config.client_id == "":
if not self.config.client_id:
self.config.client_id = endpoint_params.client_id
if self.config.authority == "":
if not self.config.authority:
self.config.authority = endpoint_params.authority
if self.config.tenant_domain == "":
if not self.config.tenant_domain:
self.config.tenant_domain = endpoint_params.tenant_domain
self.save_config()
def request(
self,
method: str,
@ -353,17 +361,9 @@ class Backend:
if not endpoint:
raise Exception("endpoint not configured")
# If file expires, remove and force user to reset
if datetime.utcnow() > self.config.expires_on:
os.remove(self.config_path)
self.config = BackendConfig(
endpoint=endpoint, authority="", client_id="", tenant_domain=""
)
url = endpoint + "/api/" + path
if self.config.client_id == "" or (
self.config.authority == "" and self.config.tenant_domain == ""
if not self.config.client_id or (
not self.config.authority and not self.config.tenant_domain
):
self.config_params()
headers = self.headers()