mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-18 04:38:09 +00:00
Add SDK Feature Flags (#313)
## Summary of the Pull Request This enables feature flags for the SDK, which enables gating access to preview features to those that have specifically asked for them. This is intended to be used within #266. Note, this change also moves to using a `pydantic` model for the config, rather than hand-crafted JSON dicts.
This commit is contained in:
@ -9,6 +9,7 @@ import os
|
|||||||
import re
|
import re
|
||||||
import subprocess # nosec
|
import subprocess # nosec
|
||||||
import uuid
|
import uuid
|
||||||
|
from enum import Enum
|
||||||
from shutil import which
|
from shutil import which
|
||||||
from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, cast
|
from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, cast
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
@ -21,15 +22,15 @@ from pydantic import BaseModel
|
|||||||
from six.moves import input # workaround for static analysis
|
from six.moves import input # workaround for static analysis
|
||||||
|
|
||||||
from .__version__ import __version__
|
from .__version__ import __version__
|
||||||
from .backend import Backend, ContainerWrapper, wait
|
from .backend import Backend, BackendConfig, ContainerWrapper, wait
|
||||||
from .ssh import build_ssh_command, ssh_connect, temp_file
|
from .ssh import build_ssh_command, ssh_connect, temp_file
|
||||||
|
|
||||||
UUID_EXPANSION = TypeVar("UUID_EXPANSION", UUID, str)
|
UUID_EXPANSION = TypeVar("UUID_EXPANSION", UUID, str)
|
||||||
|
|
||||||
DEFAULT = {
|
DEFAULT = BackendConfig(
|
||||||
"authority": "https://login.microsoftonline.com/72f988bf-86f1-41af-91ab-2d7cd011db47",
|
authority="https://login.microsoftonline.com/72f988bf-86f1-41af-91ab-2d7cd011db47",
|
||||||
"client_id": "72f1562a-8c0c-41ea-beb9-fa2b71c80134",
|
client_id="72f1562a-8c0c-41ea-beb9-fa2b71c80134",
|
||||||
}
|
)
|
||||||
|
|
||||||
# This was generated randomly and should be preserved moving forwards
|
# This was generated randomly and should be preserved moving forwards
|
||||||
ONEFUZZ_GUID_NAMESPACE = uuid.UUID("27f25e3f-6544-4b69-b309-9b096c5a9cbc")
|
ONEFUZZ_GUID_NAMESPACE = uuid.UUID("27f25e3f-6544-4b69-b309-9b096c5a9cbc")
|
||||||
@ -44,6 +45,10 @@ REPRO_SSH_FORWARD = "1337:127.0.0.1:1337"
|
|||||||
UUID_RE = r"^[a-f0-9]{8}-?[a-f0-9]{4}-?[a-f0-9]{4}-?[a-f0-9]{4}-?[a-f0-9]{12}\Z"
|
UUID_RE = r"^[a-f0-9]{8}-?[a-f0-9]{4}-?[a-f0-9]{4}-?[a-f0-9]{4}-?[a-f0-9]{12}\Z"
|
||||||
|
|
||||||
|
|
||||||
|
class PreviewFeature(Enum):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def is_uuid(value: str) -> bool:
|
def is_uuid(value: str) -> bool:
|
||||||
return bool(re.match(UUID_RE, value))
|
return bool(re.match(UUID_RE, value))
|
||||||
|
|
||||||
@ -1449,7 +1454,7 @@ class Onefuzz:
|
|||||||
|
|
||||||
def __setup__(self, endpoint: Optional[str] = None) -> None:
|
def __setup__(self, endpoint: Optional[str] = None) -> None:
|
||||||
if endpoint:
|
if endpoint:
|
||||||
self._backend.config["endpoint"] = endpoint
|
self._backend.config.endpoint = endpoint
|
||||||
|
|
||||||
def licenses(self) -> object:
|
def licenses(self) -> object:
|
||||||
""" Return third-party licenses used by this package """
|
""" Return third-party licenses used by this package """
|
||||||
@ -1476,7 +1481,8 @@ class Onefuzz:
|
|||||||
authority: Optional[str] = None,
|
authority: Optional[str] = None,
|
||||||
client_id: Optional[str] = None,
|
client_id: Optional[str] = None,
|
||||||
client_secret: Optional[str] = None,
|
client_secret: Optional[str] = None,
|
||||||
) -> Dict[str, str]:
|
enable_feature: Optional[PreviewFeature] = None,
|
||||||
|
) -> BackendConfig:
|
||||||
""" Configure onefuzz CLI """
|
""" Configure onefuzz CLI """
|
||||||
self.logger.debug("set config")
|
self.logger.debug("set config")
|
||||||
|
|
||||||
@ -1493,22 +1499,24 @@ class Onefuzz:
|
|||||||
"This could be an invalid OneFuzz API endpoint: "
|
"This could be an invalid OneFuzz API endpoint: "
|
||||||
"Missing HTTP Authentication"
|
"Missing HTTP Authentication"
|
||||||
)
|
)
|
||||||
self._backend.config["endpoint"] = endpoint
|
self._backend.config.endpoint = endpoint
|
||||||
if authority is not None:
|
if authority is not None:
|
||||||
self._backend.config["authority"] = authority
|
self._backend.config.authority = authority
|
||||||
if client_id is not None:
|
if client_id is not None:
|
||||||
self._backend.config["client_id"] = client_id
|
self._backend.config.client_id = client_id
|
||||||
if client_secret is not None:
|
if client_secret is not None:
|
||||||
self._backend.config["client_secret"] = client_secret
|
self._backend.config.client_secret = client_secret
|
||||||
|
if enable_feature:
|
||||||
|
self._backend.enable_feature(enable_feature.name)
|
||||||
self._backend.app = None
|
self._backend.app = None
|
||||||
self._backend.save_config()
|
self._backend.save_config()
|
||||||
|
|
||||||
data: Dict[str, str] = self._backend.config.copy()
|
data = self._backend.config.copy(deep=True)
|
||||||
if "client_secret" in data:
|
if data.client_secret is not None:
|
||||||
# replace existing secrets with "*** for user display
|
# replace existing secrets with "*** for user display
|
||||||
data["client_secret"] = "***" # nosec
|
data.client_secret = "***" # nosec
|
||||||
|
|
||||||
if not data["endpoint"]:
|
if not data.endpoint:
|
||||||
self.logger.warning("endpoint not configured yet")
|
self.logger.warning("endpoint not configured yet")
|
||||||
|
|
||||||
return data
|
return data
|
||||||
@ -1655,6 +1663,12 @@ class Onefuzz:
|
|||||||
webhooks=webhooks,
|
webhooks=webhooks,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _warn_preview(self, feature: PreviewFeature) -> None:
|
||||||
|
self.logger.warning(
|
||||||
|
"%s are a preview-feature and may change in an upcoming release",
|
||||||
|
feature.name,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
from .debug import Debug # noqa: E402
|
from .debug import Debug # noqa: E402
|
||||||
from .status.cmd import Status # noqa: E402
|
from .status.cmd import Status # noqa: E402
|
||||||
|
@ -12,14 +12,25 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, TypeVar, cast
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
Generator,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Set,
|
||||||
|
Tuple,
|
||||||
|
TypeVar,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
from urllib.parse import urlparse, urlunparse
|
from urllib.parse import urlparse, urlunparse
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
import msal
|
import msal
|
||||||
import requests
|
import requests
|
||||||
from azure.storage.blob import ContainerClient
|
from azure.storage.blob import ContainerClient
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
from tenacity import Future as tenacity_future
|
from tenacity import Future as tenacity_future
|
||||||
from tenacity import Retrying, retry
|
from tenacity import Retrying, retry
|
||||||
from tenacity.retry import retry_if_exception_type
|
from tenacity.retry import retry_if_exception_type
|
||||||
@ -45,16 +56,24 @@ def _temporary_umask(new_umask: int) -> Generator[None, None, None]:
|
|||||||
os.umask(prev_umask)
|
os.umask(prev_umask)
|
||||||
|
|
||||||
|
|
||||||
|
class BackendConfig(BaseModel):
|
||||||
|
authority: str
|
||||||
|
client_id: str
|
||||||
|
client_secret: Optional[str]
|
||||||
|
endpoint: Optional[str]
|
||||||
|
features: Set[str] = Field(default_factory=set)
|
||||||
|
|
||||||
|
|
||||||
class Backend:
|
class Backend:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: Optional[Dict[str, str]] = None,
|
config: BackendConfig,
|
||||||
config_path: Optional[str] = None,
|
config_path: Optional[str] = None,
|
||||||
token_path: Optional[str] = None,
|
token_path: Optional[str] = None,
|
||||||
):
|
):
|
||||||
self.config_path = os.path.expanduser(config_path or DEFAULT_CONFIG_PATH)
|
self.config_path = os.path.expanduser(config_path or DEFAULT_CONFIG_PATH)
|
||||||
self.token_path = os.path.expanduser(token_path or DEFAULT_TOKEN_PATH)
|
self.token_path = os.path.expanduser(token_path or DEFAULT_TOKEN_PATH)
|
||||||
self.config = config or {}
|
self.config = config
|
||||||
self.token_cache: Optional[msal.SerializableTokenCache] = None
|
self.token_cache: Optional[msal.SerializableTokenCache] = None
|
||||||
self.init_cache()
|
self.init_cache()
|
||||||
self.app: Optional[Any] = None
|
self.app: Optional[Any] = None
|
||||||
@ -64,14 +83,21 @@ class Backend:
|
|||||||
|
|
||||||
atexit.register(self.save_cache)
|
atexit.register(self.save_cache)
|
||||||
|
|
||||||
|
def enable_feature(self, name: str) -> None:
|
||||||
|
self.config.features.add(name)
|
||||||
|
|
||||||
|
def is_feature_enabled(self, name: str) -> bool:
|
||||||
|
return name in self.config.features
|
||||||
|
|
||||||
def load_config(self) -> None:
|
def load_config(self) -> None:
|
||||||
if os.path.exists(self.config_path):
|
if os.path.exists(self.config_path):
|
||||||
with open(self.config_path, "r") as handle:
|
with open(self.config_path, "r") as handle:
|
||||||
self.config.update(json.load(handle))
|
data = json.load(handle)
|
||||||
|
self.config = BackendConfig.parse_obj(data)
|
||||||
|
|
||||||
def save_config(self) -> None:
|
def save_config(self) -> None:
|
||||||
with open(self.config_path, "w") as handle:
|
with open(self.config_path, "w") as handle:
|
||||||
json.dump(self.config, handle)
|
handle.write(self.config.json(indent=4, exclude_none=True))
|
||||||
|
|
||||||
def init_cache(self) -> None:
|
def init_cache(self) -> None:
|
||||||
# Ensure the token_path directory exists
|
# Ensure the token_path directory exists
|
||||||
@ -106,7 +132,7 @@ class Backend:
|
|||||||
|
|
||||||
def headers(self) -> Dict[str, str]:
|
def headers(self) -> Dict[str, str]:
|
||||||
value = {}
|
value = {}
|
||||||
if self.config["client_id"] is not None:
|
if self.config.client_id is not None:
|
||||||
access_token = self.get_access_token()
|
access_token = self.get_access_token()
|
||||||
value["Authorization"] = "%s %s" % (
|
value["Authorization"] = "%s %s" % (
|
||||||
access_token["token_type"],
|
access_token["token_type"],
|
||||||
@ -115,18 +141,21 @@ class Backend:
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
def get_access_token(self) -> Any:
|
def get_access_token(self) -> Any:
|
||||||
scopes = [self.config["endpoint"] + "/.default"]
|
if not self.config.endpoint:
|
||||||
|
raise Exception("endpoint not configured")
|
||||||
|
|
||||||
if "client_secret" in self.config:
|
scopes = [self.config.endpoint + "/.default"]
|
||||||
|
|
||||||
|
if self.config.client_secret:
|
||||||
return self.client_secret(scopes)
|
return self.client_secret(scopes)
|
||||||
return self.device_login(scopes)
|
return self.device_login(scopes)
|
||||||
|
|
||||||
def client_secret(self, scopes: List[str]) -> Any:
|
def client_secret(self, scopes: List[str]) -> Any:
|
||||||
if not self.app:
|
if not self.app:
|
||||||
self.app = msal.ConfidentialClientApplication(
|
self.app = msal.ConfidentialClientApplication(
|
||||||
self.config["client_id"],
|
self.config.client_id,
|
||||||
authority=self.config["authority"],
|
authority=self.config.authority,
|
||||||
client_credential=self.config["client_secret"],
|
client_credential=self.config.client_secret,
|
||||||
token_cache=self.token_cache,
|
token_cache=self.token_cache,
|
||||||
)
|
)
|
||||||
result = self.app.acquire_token_for_client(scopes=scopes)
|
result = self.app.acquire_token_for_client(scopes=scopes)
|
||||||
@ -140,8 +169,8 @@ class Backend:
|
|||||||
def device_login(self, scopes: List[str]) -> Any:
|
def device_login(self, scopes: List[str]) -> Any:
|
||||||
if not self.app:
|
if not self.app:
|
||||||
self.app = msal.PublicClientApplication(
|
self.app = msal.PublicClientApplication(
|
||||||
self.config["client_id"],
|
self.config.client_id,
|
||||||
authority=self.config["authority"],
|
authority=self.config.authority,
|
||||||
token_cache=self.token_cache,
|
token_cache=self.token_cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -187,9 +216,9 @@ class Backend:
|
|||||||
params: Optional[Any] = None,
|
params: Optional[Any] = None,
|
||||||
_retry_on_auth_failure: bool = True,
|
_retry_on_auth_failure: bool = True,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
if not self.config["endpoint"]:
|
if not self.config.endpoint:
|
||||||
raise Exception("endpoint not configured")
|
raise Exception("endpoint not configured")
|
||||||
url = self.config["endpoint"] + "/api/" + path
|
url = self.config.endpoint + "/api/" + path
|
||||||
headers = self.headers()
|
headers = self.headers()
|
||||||
json_data = serialize(json_data)
|
json_data = serialize(json_data)
|
||||||
|
|
||||||
|
@ -100,7 +100,10 @@ class TopCache:
|
|||||||
self.nodes: Dict[UUID, Tuple[datetime, Node]] = {}
|
self.nodes: Dict[UUID, Tuple[datetime, Node]] = {}
|
||||||
|
|
||||||
self.messages: List[MESSAGE] = []
|
self.messages: List[MESSAGE] = []
|
||||||
self.endpoint: str = onefuzz._backend.config["endpoint"]
|
endpoint = onefuzz._backend.config.endpoint
|
||||||
|
if not endpoint:
|
||||||
|
raise Exception("endpoint is not set")
|
||||||
|
self.endpoint: str = endpoint
|
||||||
self.last_update = datetime.now()
|
self.last_update = datetime.now()
|
||||||
|
|
||||||
def add_container(self, name: str, ignore_date: bool = False) -> None:
|
def add_container(self, name: str, ignore_date: bool = False) -> None:
|
||||||
|
Reference in New Issue
Block a user