mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-17 12:28:07 +00:00
Remove Additional config
params - require on each request (#3000)
* Only Overrite Config Cache * Lint * Fixing isort. * Removing expiry. * Removing import. * Removing config params. * Remove bad import. * Adjusting to type changes. * Remove whitespace. * Formatting. * Formatting. * null check. * Formatting.
This commit is contained in:
committed by
GitHub
parent
77c42930a6
commit
169cef7a06
@ -12,8 +12,6 @@ def main(req: func.HttpRequest) -> func.HttpResponse:
|
|||||||
o = Onefuzz()
|
o = Onefuzz()
|
||||||
o.config(
|
o.config(
|
||||||
endpoint=os.environ.get("ONEFUZZ_ENDPOINT"),
|
endpoint=os.environ.get("ONEFUZZ_ENDPOINT"),
|
||||||
override_authority=os.environ.get("ONEFUZZ_AUTHORITY"),
|
|
||||||
client_id=os.environ.get("ONEFUZZ_CLIENT_ID"),
|
|
||||||
)
|
)
|
||||||
info = o.info.get()
|
info = o.info.get()
|
||||||
return func.HttpResponse(info.json())
|
return func.HttpResponse(info.json())
|
||||||
|
@ -40,11 +40,7 @@ from .ssh import build_ssh_command, ssh_connect, temp_file
|
|||||||
|
|
||||||
UUID_EXPANSION = TypeVar("UUID_EXPANSION", UUID, str)
|
UUID_EXPANSION = TypeVar("UUID_EXPANSION", UUID, str)
|
||||||
|
|
||||||
DEFAULT = BackendConfig(
|
DEFAULT = BackendConfig(endpoint="")
|
||||||
authority="",
|
|
||||||
client_id="",
|
|
||||||
tenant_domain="",
|
|
||||||
)
|
|
||||||
|
|
||||||
# This was generated randomly and should be preserved moving forwards
|
# This was generated randomly and should be preserved moving forwards
|
||||||
ONEFUZZ_GUID_NAMESPACE = uuid.UUID("27f25e3f-6544-4b69-b309-9b096c5a9cbc")
|
ONEFUZZ_GUID_NAMESPACE = uuid.UUID("27f25e3f-6544-4b69-b309-9b096c5a9cbc")
|
||||||
@ -1310,7 +1306,7 @@ class Pool(Endpoint):
|
|||||||
raise Exception("Missing AgentConfig in response")
|
raise Exception("Missing AgentConfig in response")
|
||||||
|
|
||||||
config = pool.config
|
config = pool.config
|
||||||
if not pool.managed:
|
if not pool.managed and self.onefuzz._backend.config.authority:
|
||||||
config.client_credentials = models.ClientCredentials( # nosec
|
config.client_credentials = models.ClientCredentials( # nosec
|
||||||
client_id=uuid.UUID(int=0),
|
client_id=uuid.UUID(int=0),
|
||||||
client_secret="<client_secret>",
|
client_secret="<client_secret>",
|
||||||
@ -1894,9 +1890,6 @@ class Onefuzz:
|
|||||||
def config(
|
def config(
|
||||||
self,
|
self,
|
||||||
endpoint: Optional[str] = None,
|
endpoint: Optional[str] = None,
|
||||||
override_authority: Optional[str] = None,
|
|
||||||
client_id: Optional[str] = None,
|
|
||||||
override_tenant_domain: Optional[str] = None,
|
|
||||||
enable_feature: Optional[PreviewFeature] = None,
|
enable_feature: Optional[PreviewFeature] = None,
|
||||||
reset: Optional[bool] = None,
|
reset: Optional[bool] = None,
|
||||||
) -> BackendConfig:
|
) -> BackendConfig:
|
||||||
@ -1904,9 +1897,7 @@ class Onefuzz:
|
|||||||
self.logger.debug("set config")
|
self.logger.debug("set config")
|
||||||
|
|
||||||
if reset:
|
if reset:
|
||||||
self._backend.config = BackendConfig(
|
self._backend.config = BackendConfig(endpoint="")
|
||||||
authority="", client_id="", tenant_domain=""
|
|
||||||
)
|
|
||||||
|
|
||||||
if endpoint is not None:
|
if endpoint is not None:
|
||||||
# The normal path for calling the API always uses the oauth2 workflow,
|
# The normal path for calling the API always uses the oauth2 workflow,
|
||||||
@ -1922,17 +1913,12 @@ class Onefuzz:
|
|||||||
"Missing HTTP Authentication"
|
"Missing HTTP Authentication"
|
||||||
)
|
)
|
||||||
self._backend.config.endpoint = endpoint
|
self._backend.config.endpoint = endpoint
|
||||||
if client_id is not None:
|
|
||||||
self._backend.config.client_id = client_id
|
|
||||||
if override_authority is not None:
|
|
||||||
self._backend.config.authority = override_authority
|
|
||||||
if enable_feature:
|
if enable_feature:
|
||||||
self._backend.enable_feature(enable_feature.name)
|
self._backend.enable_feature(enable_feature.name)
|
||||||
if override_tenant_domain is not None:
|
|
||||||
self._backend.config.tenant_domain = override_tenant_domain
|
|
||||||
self._backend.app = None
|
self._backend.app = None
|
||||||
self._backend.save_config()
|
self._backend.save_config()
|
||||||
|
|
||||||
data = self._backend.config.copy(deep=True)
|
data = self._backend.config.copy(deep=True)
|
||||||
|
|
||||||
if not data.endpoint:
|
if not data.endpoint:
|
||||||
|
@ -12,7 +12,6 @@ import sys
|
|||||||
import tempfile
|
import tempfile
|
||||||
import time
|
import time
|
||||||
from dataclasses import asdict, is_dataclass
|
from dataclasses import asdict, is_dataclass
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
@ -33,7 +32,7 @@ import msal
|
|||||||
import requests
|
import requests
|
||||||
from azure.storage.blob import ContainerClient
|
from azure.storage.blob import ContainerClient
|
||||||
from onefuzztypes import responses
|
from onefuzztypes import responses
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel
|
||||||
from requests import Response
|
from requests import Response
|
||||||
from tenacity import RetryCallState, retry
|
from tenacity import RetryCallState, retry
|
||||||
from tenacity.retry import retry_if_exception_type
|
from tenacity.retry import retry_if_exception_type
|
||||||
@ -93,20 +92,26 @@ def check_application_error(response: requests.Response) -> None:
|
|||||||
|
|
||||||
|
|
||||||
class BackendConfig(BaseModel):
|
class BackendConfig(BaseModel):
|
||||||
authority: str
|
authority: Optional[str]
|
||||||
client_id: str
|
client_id: Optional[str]
|
||||||
endpoint: Optional[str]
|
endpoint: str
|
||||||
features: Set[str] = Field(default_factory=set)
|
features: Optional[Set[str]]
|
||||||
tenant_domain: str
|
tenant_domain: Optional[str]
|
||||||
expires_on: datetime = datetime.utcnow() + timedelta(hours=24)
|
|
||||||
|
|
||||||
def get_multi_tenant_domain(self) -> Optional[str]:
|
def get_multi_tenant_domain(self) -> Optional[str]:
|
||||||
if "https://login.microsoftonline.com/common" in self.authority:
|
if (
|
||||||
|
self.authority
|
||||||
|
and "https://login.microsoftonline.com/common" in self.authority
|
||||||
|
):
|
||||||
return self.tenant_domain
|
return self.tenant_domain
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class CacheConfig(BaseModel):
|
||||||
|
endpoint: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
class Backend:
|
class Backend:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -129,10 +134,14 @@ class Backend:
|
|||||||
atexit.register(self.save_cache)
|
atexit.register(self.save_cache)
|
||||||
|
|
||||||
def enable_feature(self, name: str) -> None:
|
def enable_feature(self, name: str) -> None:
|
||||||
|
if not self.config.features:
|
||||||
|
self.config.features = Set[str]()
|
||||||
self.config.features.add(name)
|
self.config.features.add(name)
|
||||||
|
|
||||||
def is_feature_enabled(self, name: str) -> bool:
|
def is_feature_enabled(self, name: str) -> bool:
|
||||||
return name in self.config.features
|
if self.config.features:
|
||||||
|
return name in self.config.features
|
||||||
|
return False
|
||||||
|
|
||||||
def load_config(self) -> None:
|
def load_config(self) -> None:
|
||||||
if os.path.exists(self.config_path):
|
if os.path.exists(self.config_path):
|
||||||
@ -143,7 +152,8 @@ class Backend:
|
|||||||
def save_config(self) -> None:
|
def save_config(self) -> None:
|
||||||
os.makedirs(os.path.dirname(self.config_path), exist_ok=True)
|
os.makedirs(os.path.dirname(self.config_path), exist_ok=True)
|
||||||
with open(self.config_path, "w") as handle:
|
with open(self.config_path, "w") as handle:
|
||||||
handle.write(self.config.json(indent=4, exclude_none=True))
|
endpoint_cache = {"endpoint": f"{self.config.endpoint}"}
|
||||||
|
handle.write(json.dumps(endpoint_cache, indent=4, sort_keys=True))
|
||||||
|
|
||||||
def init_cache(self) -> None:
|
def init_cache(self) -> None:
|
||||||
# Ensure the token_path directory exists
|
# Ensure the token_path directory exists
|
||||||
@ -331,15 +341,13 @@ class Backend:
|
|||||||
endpoint_params = responses.Config.parse_obj(response.json())
|
endpoint_params = responses.Config.parse_obj(response.json())
|
||||||
|
|
||||||
# Will override values in storage w/ provided values for SP use
|
# Will override values in storage w/ provided values for SP use
|
||||||
if self.config.client_id == "":
|
if not self.config.client_id:
|
||||||
self.config.client_id = endpoint_params.client_id
|
self.config.client_id = endpoint_params.client_id
|
||||||
if self.config.authority == "":
|
if not self.config.authority:
|
||||||
self.config.authority = endpoint_params.authority
|
self.config.authority = endpoint_params.authority
|
||||||
if self.config.tenant_domain == "":
|
if not self.config.tenant_domain:
|
||||||
self.config.tenant_domain = endpoint_params.tenant_domain
|
self.config.tenant_domain = endpoint_params.tenant_domain
|
||||||
|
|
||||||
self.save_config()
|
|
||||||
|
|
||||||
def request(
|
def request(
|
||||||
self,
|
self,
|
||||||
method: str,
|
method: str,
|
||||||
@ -353,17 +361,9 @@ class Backend:
|
|||||||
if not endpoint:
|
if not endpoint:
|
||||||
raise Exception("endpoint not configured")
|
raise Exception("endpoint not configured")
|
||||||
|
|
||||||
# If file expires, remove and force user to reset
|
|
||||||
if datetime.utcnow() > self.config.expires_on:
|
|
||||||
os.remove(self.config_path)
|
|
||||||
self.config = BackendConfig(
|
|
||||||
endpoint=endpoint, authority="", client_id="", tenant_domain=""
|
|
||||||
)
|
|
||||||
|
|
||||||
url = endpoint + "/api/" + path
|
url = endpoint + "/api/" + path
|
||||||
|
if not self.config.client_id or (
|
||||||
if self.config.client_id == "" or (
|
not self.config.authority and not self.config.tenant_domain
|
||||||
self.config.authority == "" and self.config.tenant_domain == ""
|
|
||||||
):
|
):
|
||||||
self.config_params()
|
self.config_params()
|
||||||
headers = self.headers()
|
headers = self.headers()
|
||||||
|
Reference in New Issue
Block a user