mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-17 04:18:07 +00:00
Remove Additional config
params - require on each request (#3000)
* Only Overrite Config Cache * Lint * Fixing isort. * Removing expiry. * Removing import. * Removing config params. * Remove bad import. * Adjusting to type changes. * Remove whitespace. * Formatting. * Formatting. * null check. * Formatting.
This commit is contained in:
committed by
GitHub
parent
77c42930a6
commit
169cef7a06
@ -12,8 +12,6 @@ def main(req: func.HttpRequest) -> func.HttpResponse:
|
||||
o = Onefuzz()
|
||||
o.config(
|
||||
endpoint=os.environ.get("ONEFUZZ_ENDPOINT"),
|
||||
override_authority=os.environ.get("ONEFUZZ_AUTHORITY"),
|
||||
client_id=os.environ.get("ONEFUZZ_CLIENT_ID"),
|
||||
)
|
||||
info = o.info.get()
|
||||
return func.HttpResponse(info.json())
|
||||
|
@ -40,11 +40,7 @@ from .ssh import build_ssh_command, ssh_connect, temp_file
|
||||
|
||||
UUID_EXPANSION = TypeVar("UUID_EXPANSION", UUID, str)
|
||||
|
||||
DEFAULT = BackendConfig(
|
||||
authority="",
|
||||
client_id="",
|
||||
tenant_domain="",
|
||||
)
|
||||
DEFAULT = BackendConfig(endpoint="")
|
||||
|
||||
# This was generated randomly and should be preserved moving forwards
|
||||
ONEFUZZ_GUID_NAMESPACE = uuid.UUID("27f25e3f-6544-4b69-b309-9b096c5a9cbc")
|
||||
@ -1310,7 +1306,7 @@ class Pool(Endpoint):
|
||||
raise Exception("Missing AgentConfig in response")
|
||||
|
||||
config = pool.config
|
||||
if not pool.managed:
|
||||
if not pool.managed and self.onefuzz._backend.config.authority:
|
||||
config.client_credentials = models.ClientCredentials( # nosec
|
||||
client_id=uuid.UUID(int=0),
|
||||
client_secret="<client_secret>",
|
||||
@ -1894,9 +1890,6 @@ class Onefuzz:
|
||||
def config(
|
||||
self,
|
||||
endpoint: Optional[str] = None,
|
||||
override_authority: Optional[str] = None,
|
||||
client_id: Optional[str] = None,
|
||||
override_tenant_domain: Optional[str] = None,
|
||||
enable_feature: Optional[PreviewFeature] = None,
|
||||
reset: Optional[bool] = None,
|
||||
) -> BackendConfig:
|
||||
@ -1904,9 +1897,7 @@ class Onefuzz:
|
||||
self.logger.debug("set config")
|
||||
|
||||
if reset:
|
||||
self._backend.config = BackendConfig(
|
||||
authority="", client_id="", tenant_domain=""
|
||||
)
|
||||
self._backend.config = BackendConfig(endpoint="")
|
||||
|
||||
if endpoint is not None:
|
||||
# The normal path for calling the API always uses the oauth2 workflow,
|
||||
@ -1922,17 +1913,12 @@ class Onefuzz:
|
||||
"Missing HTTP Authentication"
|
||||
)
|
||||
self._backend.config.endpoint = endpoint
|
||||
if client_id is not None:
|
||||
self._backend.config.client_id = client_id
|
||||
if override_authority is not None:
|
||||
self._backend.config.authority = override_authority
|
||||
|
||||
if enable_feature:
|
||||
self._backend.enable_feature(enable_feature.name)
|
||||
if override_tenant_domain is not None:
|
||||
self._backend.config.tenant_domain = override_tenant_domain
|
||||
|
||||
self._backend.app = None
|
||||
self._backend.save_config()
|
||||
|
||||
data = self._backend.config.copy(deep=True)
|
||||
|
||||
if not data.endpoint:
|
||||
|
@ -12,7 +12,6 @@ import sys
|
||||
import tempfile
|
||||
import time
|
||||
from dataclasses import asdict, is_dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Any,
|
||||
@ -33,7 +32,7 @@ import msal
|
||||
import requests
|
||||
from azure.storage.blob import ContainerClient
|
||||
from onefuzztypes import responses
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
from requests import Response
|
||||
from tenacity import RetryCallState, retry
|
||||
from tenacity.retry import retry_if_exception_type
|
||||
@ -93,20 +92,26 @@ def check_application_error(response: requests.Response) -> None:
|
||||
|
||||
|
||||
class BackendConfig(BaseModel):
|
||||
authority: str
|
||||
client_id: str
|
||||
endpoint: Optional[str]
|
||||
features: Set[str] = Field(default_factory=set)
|
||||
tenant_domain: str
|
||||
expires_on: datetime = datetime.utcnow() + timedelta(hours=24)
|
||||
authority: Optional[str]
|
||||
client_id: Optional[str]
|
||||
endpoint: str
|
||||
features: Optional[Set[str]]
|
||||
tenant_domain: Optional[str]
|
||||
|
||||
def get_multi_tenant_domain(self) -> Optional[str]:
|
||||
if "https://login.microsoftonline.com/common" in self.authority:
|
||||
if (
|
||||
self.authority
|
||||
and "https://login.microsoftonline.com/common" in self.authority
|
||||
):
|
||||
return self.tenant_domain
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class CacheConfig(BaseModel):
|
||||
endpoint: Optional[str]
|
||||
|
||||
|
||||
class Backend:
|
||||
def __init__(
|
||||
self,
|
||||
@ -129,10 +134,14 @@ class Backend:
|
||||
atexit.register(self.save_cache)
|
||||
|
||||
def enable_feature(self, name: str) -> None:
|
||||
if not self.config.features:
|
||||
self.config.features = Set[str]()
|
||||
self.config.features.add(name)
|
||||
|
||||
def is_feature_enabled(self, name: str) -> bool:
|
||||
return name in self.config.features
|
||||
if self.config.features:
|
||||
return name in self.config.features
|
||||
return False
|
||||
|
||||
def load_config(self) -> None:
|
||||
if os.path.exists(self.config_path):
|
||||
@ -143,7 +152,8 @@ class Backend:
|
||||
def save_config(self) -> None:
|
||||
os.makedirs(os.path.dirname(self.config_path), exist_ok=True)
|
||||
with open(self.config_path, "w") as handle:
|
||||
handle.write(self.config.json(indent=4, exclude_none=True))
|
||||
endpoint_cache = {"endpoint": f"{self.config.endpoint}"}
|
||||
handle.write(json.dumps(endpoint_cache, indent=4, sort_keys=True))
|
||||
|
||||
def init_cache(self) -> None:
|
||||
# Ensure the token_path directory exists
|
||||
@ -331,15 +341,13 @@ class Backend:
|
||||
endpoint_params = responses.Config.parse_obj(response.json())
|
||||
|
||||
# Will override values in storage w/ provided values for SP use
|
||||
if self.config.client_id == "":
|
||||
if not self.config.client_id:
|
||||
self.config.client_id = endpoint_params.client_id
|
||||
if self.config.authority == "":
|
||||
if not self.config.authority:
|
||||
self.config.authority = endpoint_params.authority
|
||||
if self.config.tenant_domain == "":
|
||||
if not self.config.tenant_domain:
|
||||
self.config.tenant_domain = endpoint_params.tenant_domain
|
||||
|
||||
self.save_config()
|
||||
|
||||
def request(
|
||||
self,
|
||||
method: str,
|
||||
@ -353,17 +361,9 @@ class Backend:
|
||||
if not endpoint:
|
||||
raise Exception("endpoint not configured")
|
||||
|
||||
# If file expires, remove and force user to reset
|
||||
if datetime.utcnow() > self.config.expires_on:
|
||||
os.remove(self.config_path)
|
||||
self.config = BackendConfig(
|
||||
endpoint=endpoint, authority="", client_id="", tenant_domain=""
|
||||
)
|
||||
|
||||
url = endpoint + "/api/" + path
|
||||
|
||||
if self.config.client_id == "" or (
|
||||
self.config.authority == "" and self.config.tenant_domain == ""
|
||||
if not self.config.client_id or (
|
||||
not self.config.authority and not self.config.tenant_domain
|
||||
):
|
||||
self.config_params()
|
||||
headers = self.headers()
|
||||
|
Reference in New Issue
Block a user