mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-14 19:08:08 +00:00
use the primitive types in more places (#514)
This commit is contained in:
@ -16,7 +16,7 @@ from azure.mgmt.subscription import SubscriptionClient
|
||||
from memoization import cached
|
||||
from msrestazure.azure_active_directory import MSIAuthentication
|
||||
from msrestazure.tools import parse_resource_id
|
||||
from onefuzztypes.primitives import Container
|
||||
from onefuzztypes.primitives import Container, Region
|
||||
|
||||
from .monkeypatch import allow_more_workers, reduce_logging
|
||||
|
||||
@ -41,12 +41,12 @@ def get_base_resource_group() -> Any: # should be str
|
||||
|
||||
|
||||
@cached
|
||||
def get_base_region() -> Any: # should be str
|
||||
def get_base_region() -> Region:
|
||||
client = ResourceManagementClient(
|
||||
credential=get_identity(), subscription_id=get_subscription()
|
||||
)
|
||||
group = client.resource_groups.get(get_base_resource_group())
|
||||
return group.location
|
||||
return Region(group.location)
|
||||
|
||||
|
||||
@cached
|
||||
@ -89,11 +89,11 @@ DAY_IN_SECONDS = 60 * 60 * 24
|
||||
|
||||
|
||||
@cached(ttl=DAY_IN_SECONDS)
|
||||
def get_regions() -> List[str]:
|
||||
def get_regions() -> List[Region]:
|
||||
subscription = get_subscription()
|
||||
client = SubscriptionClient(credential=get_identity())
|
||||
locations = client.subscriptions.list_locations(subscription)
|
||||
return sorted([x.name for x in locations])
|
||||
return sorted([Region(x.name) for x in locations])
|
||||
|
||||
|
||||
@cached
|
||||
|
@ -49,7 +49,7 @@ def check_container(
|
||||
compare: Compare,
|
||||
expected: int,
|
||||
container_type: ContainerType,
|
||||
containers: Dict[ContainerType, List[str]],
|
||||
containers: Dict[ContainerType, List[Container]],
|
||||
) -> None:
|
||||
actual = len(containers.get(container_type, []))
|
||||
if not check_val(compare, expected, actual):
|
||||
@ -62,7 +62,7 @@ def check_container(
|
||||
def check_containers(definition: TaskDefinition, config: TaskConfig) -> None:
|
||||
checked = set()
|
||||
|
||||
containers: Dict[ContainerType, List[str]] = {}
|
||||
containers: Dict[ContainerType, List[Container]] = {}
|
||||
for container in config.containers:
|
||||
if container.name not in checked:
|
||||
if not container_exists(container.name, StorageType.corpus):
|
||||
|
@ -18,6 +18,7 @@ from onefuzztypes.events import (
|
||||
from onefuzztypes.models import Error
|
||||
from onefuzztypes.models import Task as BASE_TASK
|
||||
from onefuzztypes.models import TaskConfig, TaskVm, UserInfo
|
||||
from onefuzztypes.primitives import PoolName
|
||||
|
||||
from ..azure.image import get_os
|
||||
from ..azure.queue import create_queue, delete_queue
|
||||
@ -165,7 +166,7 @@ class Task(BASE_TASK, ORMMixin):
|
||||
return task
|
||||
|
||||
@classmethod
|
||||
def get_tasks_by_pool_name(cls, pool_name: str) -> List["Task"]:
|
||||
def get_tasks_by_pool_name(cls, pool_name: PoolName) -> List["Task"]:
|
||||
tasks = cls.search_states(states=TaskState.available())
|
||||
if not tasks:
|
||||
return []
|
||||
|
@ -72,7 +72,7 @@ class Node(BASE_NODE, ORMMixin):
|
||||
*,
|
||||
scaleset_id: Optional[UUID] = None,
|
||||
states: Optional[List[NodeState]] = None,
|
||||
pool_name: Optional[str] = None,
|
||||
pool_name: Optional[PoolName] = None,
|
||||
) -> List["Node"]:
|
||||
query: QueryFilter = {}
|
||||
if scaleset_id:
|
||||
@ -89,7 +89,7 @@ class Node(BASE_NODE, ORMMixin):
|
||||
*,
|
||||
scaleset_id: Optional[UUID] = None,
|
||||
states: Optional[List[NodeState]] = None,
|
||||
pool_name: Optional[str] = None,
|
||||
pool_name: Optional[PoolName] = None,
|
||||
exclude_update_scheduled: bool = False,
|
||||
num_results: Optional[int] = None,
|
||||
) -> List["Node"]:
|
||||
|
@ -12,6 +12,7 @@ from subprocess import PIPE, CalledProcessError, check_call # nosec
|
||||
from typing import List, Optional
|
||||
|
||||
from onefuzztypes.models import NotificationConfig
|
||||
from onefuzztypes.primitives import PoolName
|
||||
|
||||
from onefuzz.api import Command, Onefuzz
|
||||
from onefuzz.cli import execute_api
|
||||
@ -42,7 +43,7 @@ class Ossfuzz(Command):
|
||||
self,
|
||||
project: str,
|
||||
build: str,
|
||||
pool: str,
|
||||
pool: PoolName,
|
||||
sanitizers: Optional[List[str]] = None,
|
||||
notification_config: Optional[NotificationConfig] = None,
|
||||
) -> None:
|
||||
|
@ -170,23 +170,23 @@ class Files(Endpoint):
|
||||
endpoint = "files"
|
||||
|
||||
@cached(ttl=ONE_HOUR_IN_SECONDS)
|
||||
def _get_client(self, container: str) -> ContainerWrapper:
|
||||
def _get_client(self, container: primitives.Container) -> ContainerWrapper:
|
||||
sas = self.onefuzz.containers.get(container).sas_url
|
||||
return ContainerWrapper(sas)
|
||||
|
||||
def list(self, container: str) -> models.Files:
|
||||
def list(self, container: primitives.Container) -> models.Files:
|
||||
""" Get a list of files in a container """
|
||||
self.logger.debug("listing files in container: %s", container)
|
||||
client = self._get_client(container)
|
||||
return models.Files(files=client.list_blobs())
|
||||
|
||||
def delete(self, container: str, filename: str) -> None:
|
||||
def delete(self, container: primitives.Container, filename: str) -> None:
|
||||
""" delete a file from a container """
|
||||
self.logger.debug("deleting in container: %s:%s", container, filename)
|
||||
client = self._get_client(container)
|
||||
client.delete_blob(filename)
|
||||
|
||||
def get(self, container: str, filename: str) -> bytes:
|
||||
def get(self, container: primitives.Container, filename: str) -> bytes:
|
||||
""" get a file from a container """
|
||||
self.logger.debug("getting file from container: %s:%s", container, filename)
|
||||
client = self._get_client(container)
|
||||
@ -194,7 +194,10 @@ class Files(Endpoint):
|
||||
return downloaded
|
||||
|
||||
def upload_file(
|
||||
self, container: str, file_path: str, blob_name: Optional[str] = None
|
||||
self,
|
||||
container: primitives.Container,
|
||||
file_path: str,
|
||||
blob_name: Optional[str] = None,
|
||||
) -> None:
|
||||
""" uploads a file to a container """
|
||||
if not blob_name:
|
||||
@ -212,7 +215,7 @@ class Files(Endpoint):
|
||||
client = self._get_client(container)
|
||||
client.upload_file(file_path, blob_name)
|
||||
|
||||
def upload_dir(self, container: str, dir_path: str) -> None:
|
||||
def upload_dir(self, container: primitives.Container, dir_path: str) -> None:
|
||||
""" uploads a directory to a container """
|
||||
|
||||
self.logger.debug("uploading directory to container %s:%s", container, dir_path)
|
||||
@ -476,7 +479,9 @@ class Repro(Endpoint):
|
||||
"GET", models.Repro, data=requests.ReproGet(vm_id=vm_id_expanded)
|
||||
)
|
||||
|
||||
def create(self, container: str, path: str, duration: int = 24) -> models.Repro:
|
||||
def create(
|
||||
self, container: primitives.Container, path: str, duration: int = 24
|
||||
) -> models.Repro:
|
||||
""" Create a Reproduction VM from a Crash Report """
|
||||
self.logger.info(
|
||||
"creating repro vm: %s %s (%d hours)", container, path, duration
|
||||
@ -651,7 +656,7 @@ class Repro(Endpoint):
|
||||
|
||||
def create_and_connect(
|
||||
self,
|
||||
container: str,
|
||||
container: primitives.Container,
|
||||
path: str,
|
||||
duration: int = 24,
|
||||
delete_after_use: bool = False,
|
||||
@ -670,14 +675,16 @@ class Notifications(Endpoint):
|
||||
endpoint = "notifications"
|
||||
|
||||
def create(
|
||||
self, container: str, config: models.NotificationConfig
|
||||
self, container: primitives.Container, config: models.NotificationConfig
|
||||
) -> models.Notification:
|
||||
""" Create a notification based on a config file """
|
||||
|
||||
config = requests.NotificationCreate(container=container, config=config.config)
|
||||
return self._req_model("POST", models.Notification, data=config)
|
||||
|
||||
def create_teams(self, container: str, url: str) -> models.Notification:
|
||||
def create_teams(
|
||||
self, container: primitives.Container, url: str
|
||||
) -> models.Notification:
|
||||
""" Create a Teams notification integration """
|
||||
|
||||
self.logger.debug("create teams notification integration: %s", container)
|
||||
@ -687,7 +694,7 @@ class Notifications(Endpoint):
|
||||
|
||||
def create_ado(
|
||||
self,
|
||||
container: str,
|
||||
container: primitives.Container,
|
||||
project: str,
|
||||
base_url: str,
|
||||
auth_token: str,
|
||||
@ -804,7 +811,7 @@ class Tasks(Endpoint):
|
||||
ensemble_sync_delay: Optional[int] = None,
|
||||
generator_exe: Optional[str] = None,
|
||||
generator_options: Optional[List[str]] = None,
|
||||
pool_name: str,
|
||||
pool_name: primitives.PoolName,
|
||||
prereq_tasks: Optional[List[UUID]] = None,
|
||||
reboot_after_setup: bool = False,
|
||||
rename_output: bool = False,
|
||||
@ -1049,7 +1056,7 @@ class Pool(Endpoint):
|
||||
),
|
||||
)
|
||||
|
||||
def get_config(self, pool_name: str) -> models.AgentConfig:
|
||||
def get_config(self, pool_name: primitives.PoolName) -> models.AgentConfig:
|
||||
""" Get the agent configuration for the pool """
|
||||
|
||||
pool = self.get(pool_name)
|
||||
@ -1168,17 +1175,19 @@ class Node(Endpoint):
|
||||
*,
|
||||
state: Optional[List[enums.NodeState]] = None,
|
||||
scaleset_id: Optional[UUID_EXPANSION] = None,
|
||||
pool_name: Optional[str] = None,
|
||||
pool_name: Optional[primitives.PoolName] = None,
|
||||
) -> List[models.Node]:
|
||||
self.logger.debug("list nodes")
|
||||
scaleset_id_expanded: Optional[UUID] = None
|
||||
|
||||
if pool_name is not None:
|
||||
pool_name = self._disambiguate(
|
||||
"name",
|
||||
pool_name,
|
||||
lambda x: False,
|
||||
lambda: [x.name for x in self.onefuzz.pools.list()],
|
||||
pool_name = primitives.PoolName(
|
||||
self._disambiguate(
|
||||
"name",
|
||||
str(pool_name),
|
||||
lambda x: False,
|
||||
lambda: [x.name for x in self.onefuzz.pools.list()],
|
||||
)
|
||||
)
|
||||
|
||||
if scaleset_id is not None:
|
||||
@ -1242,12 +1251,12 @@ class Scaleset(Endpoint):
|
||||
|
||||
def create(
|
||||
self,
|
||||
pool_name: str,
|
||||
pool_name: primitives.PoolName,
|
||||
size: int,
|
||||
*,
|
||||
image: Optional[str] = None,
|
||||
vm_sku: Optional[str] = "Standard_D2s_v3",
|
||||
region: Optional[str] = None,
|
||||
region: Optional[primitives.Region] = None,
|
||||
spot_instances: bool = False,
|
||||
tags: Optional[Dict[str, str]] = None,
|
||||
) -> models.Scaleset:
|
||||
@ -1375,7 +1384,7 @@ class ScalesetProxy(Endpoint):
|
||||
),
|
||||
)
|
||||
|
||||
def reset(self, region: str) -> responses.BoolResult:
|
||||
def reset(self, region: primitives.Region) -> responses.BoolResult:
|
||||
""" Reset the proxy for an existing region """
|
||||
|
||||
return self._req_model(
|
||||
|
@ -32,7 +32,7 @@ from uuid import UUID
|
||||
import jmespath
|
||||
from docstring_parser import parse as parse_docstring
|
||||
from msrest.serialization import Model
|
||||
from onefuzztypes.primitives import Container, Directory, File
|
||||
from onefuzztypes.primitives import Container, Directory, File, PoolName, Region
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
LOGGER = logging.getLogger("cli")
|
||||
@ -158,6 +158,8 @@ class Builder:
|
||||
int: {"type": int},
|
||||
UUID: {"type": UUID},
|
||||
Container: {"type": str},
|
||||
Region: {"type": str},
|
||||
PoolName: {"type": str},
|
||||
File: {"type": arg_file},
|
||||
Directory: {"type": arg_dir},
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ from azure.applicationinsights.models import QueryBody
|
||||
from azure.common.client_factory import get_azure_cli_credentials
|
||||
from onefuzztypes.enums import ContainerType, TaskType
|
||||
from onefuzztypes.models import BlobRef, NodeAssignment, Report, Task
|
||||
from onefuzztypes.primitives import Directory
|
||||
from onefuzztypes.primitives import Container, Directory
|
||||
|
||||
from onefuzz.api import UUID_EXPANSION, Command, Onefuzz
|
||||
|
||||
@ -583,13 +583,13 @@ class DebugNotification(Command):
|
||||
|
||||
def _get_container(
|
||||
self, task: Task, container_type: ContainerType
|
||||
) -> Optional[str]:
|
||||
) -> Optional[Container]:
|
||||
for container in task.config.containers:
|
||||
if container.type == container_type:
|
||||
return container.name
|
||||
return None
|
||||
|
||||
def _get_storage_account(self, container_name: str) -> str:
|
||||
def _get_storage_account(self, container_name: Container) -> str:
|
||||
sas_url = self.onefuzz.containers.get(container_name).sas_url
|
||||
_, netloc, _, _, _, _ = urlparse(sas_url)
|
||||
return netloc.split(".")[0]
|
||||
|
@ -36,7 +36,7 @@ from onefuzztypes.models import (
|
||||
TaskContainers,
|
||||
UserInfo,
|
||||
)
|
||||
from onefuzztypes.primitives import Container
|
||||
from onefuzztypes.primitives import Container, PoolName
|
||||
from pydantic import BaseModel
|
||||
|
||||
MESSAGE = Tuple[datetime, EventType, str]
|
||||
@ -49,7 +49,7 @@ DAYS = 24 * HOURS
|
||||
# status-top only representation of a Node
|
||||
class MiniNode(BaseModel):
|
||||
machine_id: UUID
|
||||
pool_name: str
|
||||
pool_name: PoolName
|
||||
state: NodeState
|
||||
|
||||
|
||||
|
@ -71,7 +71,7 @@ class JobHelper:
|
||||
self.project = project
|
||||
self.name = name
|
||||
self.build = build
|
||||
self.to_monitor: Dict[str, int] = {}
|
||||
self.to_monitor: Dict[Container, int] = {}
|
||||
|
||||
if platform is None:
|
||||
self.platform = JobHelper.get_platform(target_exe)
|
||||
|
@ -7,7 +7,7 @@ from typing import Dict, List, Optional
|
||||
|
||||
from onefuzztypes.enums import OS, ContainerType, StatsFormat, TaskDebugFlag, TaskType
|
||||
from onefuzztypes.models import Job, NotificationConfig
|
||||
from onefuzztypes.primitives import Container, Directory, File
|
||||
from onefuzztypes.primitives import Container, Directory, File, PoolName
|
||||
|
||||
from onefuzz.api import Command
|
||||
|
||||
@ -23,7 +23,7 @@ class AFL(Command):
|
||||
name: str,
|
||||
build: str,
|
||||
*,
|
||||
pool_name: str,
|
||||
pool_name: PoolName,
|
||||
target_exe: File = File("fuzz.exe"),
|
||||
setup_dir: Optional[Directory] = None,
|
||||
vm_count: int = 2,
|
||||
|
@ -7,7 +7,7 @@ from typing import Dict, List, Optional
|
||||
|
||||
from onefuzztypes.enums import ContainerType, TaskDebugFlag, TaskType
|
||||
from onefuzztypes.models import Job, NotificationConfig
|
||||
from onefuzztypes.primitives import Container, Directory, File
|
||||
from onefuzztypes.primitives import Container, Directory, File, PoolName
|
||||
|
||||
from onefuzz.api import Command
|
||||
|
||||
@ -35,7 +35,7 @@ class Libfuzzer(Command):
|
||||
*,
|
||||
job: Job,
|
||||
containers: Dict[ContainerType, Container],
|
||||
pool_name: str,
|
||||
pool_name: PoolName,
|
||||
target_exe: str,
|
||||
vm_count: int = 2,
|
||||
reboot_after_setup: bool = False,
|
||||
@ -145,7 +145,7 @@ class Libfuzzer(Command):
|
||||
project: str,
|
||||
name: str,
|
||||
build: str,
|
||||
pool_name: str,
|
||||
pool_name: PoolName,
|
||||
*,
|
||||
target_exe: File = File("fuzz.exe"),
|
||||
setup_dir: Optional[Directory] = None,
|
||||
@ -261,7 +261,7 @@ class Libfuzzer(Command):
|
||||
project: str,
|
||||
name: str,
|
||||
build: str,
|
||||
pool_name: str,
|
||||
pool_name: PoolName,
|
||||
*,
|
||||
target_exe: File = File("fuzz.exe"),
|
||||
setup_dir: Optional[Directory] = None,
|
||||
|
@ -11,7 +11,7 @@ from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from onefuzztypes.enums import OS, ContainerType, TaskDebugFlag
|
||||
from onefuzztypes.models import NotificationConfig
|
||||
from onefuzztypes.primitives import File
|
||||
from onefuzztypes.primitives import File, PoolName
|
||||
|
||||
from onefuzz.api import Command
|
||||
from onefuzz.backend import container_file_path
|
||||
@ -110,7 +110,7 @@ class OssFuzz(Command):
|
||||
self,
|
||||
project: str,
|
||||
build: str,
|
||||
pool_name: str,
|
||||
pool_name: PoolName,
|
||||
duration: int = 24,
|
||||
tags: Optional[Dict[str, str]] = None,
|
||||
dryrun: bool = False,
|
||||
|
@ -7,7 +7,7 @@ from typing import Dict, List, Optional
|
||||
|
||||
from onefuzztypes.enums import OS, ContainerType, TaskDebugFlag, TaskType
|
||||
from onefuzztypes.models import Job, NotificationConfig
|
||||
from onefuzztypes.primitives import Container, Directory, File
|
||||
from onefuzztypes.primitives import Container, Directory, File, PoolName
|
||||
|
||||
from onefuzz.api import Command
|
||||
|
||||
@ -23,7 +23,7 @@ class Radamsa(Command):
|
||||
name: str,
|
||||
build: str,
|
||||
*,
|
||||
pool_name: str,
|
||||
pool_name: PoolName,
|
||||
target_exe: File = File("fuzz.exe"),
|
||||
setup_dir: Optional[Directory] = None,
|
||||
vm_count: int = 2,
|
||||
|
@ -12,7 +12,7 @@ from pydantic import BaseModel, Extra, Field
|
||||
|
||||
from .enums import OS, Architecture, NodeState, TaskState
|
||||
from .models import AutoScaleConfig, Error, JobConfig, Report, TaskConfig, UserInfo
|
||||
from .primitives import Container, Region
|
||||
from .primitives import Container, PoolName, Region
|
||||
from .responses import BaseResponse
|
||||
|
||||
|
||||
@ -66,7 +66,7 @@ class EventPing(BaseResponse):
|
||||
|
||||
class EventScalesetCreated(BaseEvent):
|
||||
scaleset_id: UUID
|
||||
pool_name: str
|
||||
pool_name: PoolName
|
||||
vm_sku: str
|
||||
image: str
|
||||
region: Region
|
||||
@ -75,21 +75,21 @@ class EventScalesetCreated(BaseEvent):
|
||||
|
||||
class EventScalesetFailed(BaseEvent):
|
||||
scaleset_id: UUID
|
||||
pool_name: str
|
||||
pool_name: PoolName
|
||||
error: Error
|
||||
|
||||
|
||||
class EventScalesetDeleted(BaseEvent):
|
||||
scaleset_id: UUID
|
||||
pool_name: str
|
||||
pool_name: PoolName
|
||||
|
||||
|
||||
class EventPoolDeleted(BaseEvent):
|
||||
pool_name: str
|
||||
pool_name: PoolName
|
||||
|
||||
|
||||
class EventPoolCreated(BaseEvent):
|
||||
pool_name: str
|
||||
pool_name: PoolName
|
||||
os: OS
|
||||
arch: Architecture
|
||||
managed: bool
|
||||
@ -112,19 +112,19 @@ class EventProxyFailed(BaseEvent):
|
||||
class EventNodeCreated(BaseEvent):
|
||||
machine_id: UUID
|
||||
scaleset_id: Optional[UUID]
|
||||
pool_name: str
|
||||
pool_name: PoolName
|
||||
|
||||
|
||||
class EventNodeDeleted(BaseEvent):
|
||||
machine_id: UUID
|
||||
scaleset_id: Optional[UUID]
|
||||
pool_name: str
|
||||
pool_name: PoolName
|
||||
|
||||
|
||||
class EventNodeStateUpdated(BaseEvent):
|
||||
machine_id: UUID
|
||||
scaleset_id: Optional[UUID]
|
||||
pool_name: str
|
||||
pool_name: PoolName
|
||||
state: NodeState
|
||||
|
||||
|
||||
|
@ -334,7 +334,7 @@ class ClientCredentials(BaseModel):
|
||||
class AgentConfig(BaseModel):
|
||||
client_credentials: Optional[ClientCredentials]
|
||||
onefuzz_url: str
|
||||
pool_name: str
|
||||
pool_name: PoolName
|
||||
heartbeat_queue: Optional[str]
|
||||
instrumentation_key: Optional[str]
|
||||
telemetry_key: Optional[str]
|
||||
|
@ -135,7 +135,7 @@ class NodeSearch(BaseRequest):
|
||||
machine_id: Optional[UUID]
|
||||
state: Optional[List[NodeState]]
|
||||
scaleset_id: Optional[UUID]
|
||||
pool_name: Optional[str]
|
||||
pool_name: Optional[PoolName]
|
||||
|
||||
|
||||
class NodeGet(BaseRequest):
|
||||
|
Reference in New Issue
Block a user