use the primitive types in more places (#514)

This commit is contained in:
bmc-msft
2021-02-05 13:10:37 -05:00
committed by GitHub
parent 51f4eea069
commit 1d74379a70
17 changed files with 74 additions and 61 deletions

View File

@ -16,7 +16,7 @@ from azure.mgmt.subscription import SubscriptionClient
from memoization import cached from memoization import cached
from msrestazure.azure_active_directory import MSIAuthentication from msrestazure.azure_active_directory import MSIAuthentication
from msrestazure.tools import parse_resource_id from msrestazure.tools import parse_resource_id
from onefuzztypes.primitives import Container from onefuzztypes.primitives import Container, Region
from .monkeypatch import allow_more_workers, reduce_logging from .monkeypatch import allow_more_workers, reduce_logging
@ -41,12 +41,12 @@ def get_base_resource_group() -> Any: # should be str
@cached @cached
def get_base_region() -> Any: # should be str def get_base_region() -> Region:
client = ResourceManagementClient( client = ResourceManagementClient(
credential=get_identity(), subscription_id=get_subscription() credential=get_identity(), subscription_id=get_subscription()
) )
group = client.resource_groups.get(get_base_resource_group()) group = client.resource_groups.get(get_base_resource_group())
return group.location return Region(group.location)
@cached @cached
@ -89,11 +89,11 @@ DAY_IN_SECONDS = 60 * 60 * 24
@cached(ttl=DAY_IN_SECONDS) @cached(ttl=DAY_IN_SECONDS)
def get_regions() -> List[str]: def get_regions() -> List[Region]:
subscription = get_subscription() subscription = get_subscription()
client = SubscriptionClient(credential=get_identity()) client = SubscriptionClient(credential=get_identity())
locations = client.subscriptions.list_locations(subscription) locations = client.subscriptions.list_locations(subscription)
return sorted([x.name for x in locations]) return sorted([Region(x.name) for x in locations])
@cached @cached

View File

@ -49,7 +49,7 @@ def check_container(
compare: Compare, compare: Compare,
expected: int, expected: int,
container_type: ContainerType, container_type: ContainerType,
containers: Dict[ContainerType, List[str]], containers: Dict[ContainerType, List[Container]],
) -> None: ) -> None:
actual = len(containers.get(container_type, [])) actual = len(containers.get(container_type, []))
if not check_val(compare, expected, actual): if not check_val(compare, expected, actual):
@ -62,7 +62,7 @@ def check_container(
def check_containers(definition: TaskDefinition, config: TaskConfig) -> None: def check_containers(definition: TaskDefinition, config: TaskConfig) -> None:
checked = set() checked = set()
containers: Dict[ContainerType, List[str]] = {} containers: Dict[ContainerType, List[Container]] = {}
for container in config.containers: for container in config.containers:
if container.name not in checked: if container.name not in checked:
if not container_exists(container.name, StorageType.corpus): if not container_exists(container.name, StorageType.corpus):

View File

@ -18,6 +18,7 @@ from onefuzztypes.events import (
from onefuzztypes.models import Error from onefuzztypes.models import Error
from onefuzztypes.models import Task as BASE_TASK from onefuzztypes.models import Task as BASE_TASK
from onefuzztypes.models import TaskConfig, TaskVm, UserInfo from onefuzztypes.models import TaskConfig, TaskVm, UserInfo
from onefuzztypes.primitives import PoolName
from ..azure.image import get_os from ..azure.image import get_os
from ..azure.queue import create_queue, delete_queue from ..azure.queue import create_queue, delete_queue
@ -165,7 +166,7 @@ class Task(BASE_TASK, ORMMixin):
return task return task
@classmethod @classmethod
def get_tasks_by_pool_name(cls, pool_name: str) -> List["Task"]: def get_tasks_by_pool_name(cls, pool_name: PoolName) -> List["Task"]:
tasks = cls.search_states(states=TaskState.available()) tasks = cls.search_states(states=TaskState.available())
if not tasks: if not tasks:
return [] return []

View File

@ -72,7 +72,7 @@ class Node(BASE_NODE, ORMMixin):
*, *,
scaleset_id: Optional[UUID] = None, scaleset_id: Optional[UUID] = None,
states: Optional[List[NodeState]] = None, states: Optional[List[NodeState]] = None,
pool_name: Optional[str] = None, pool_name: Optional[PoolName] = None,
) -> List["Node"]: ) -> List["Node"]:
query: QueryFilter = {} query: QueryFilter = {}
if scaleset_id: if scaleset_id:
@ -89,7 +89,7 @@ class Node(BASE_NODE, ORMMixin):
*, *,
scaleset_id: Optional[UUID] = None, scaleset_id: Optional[UUID] = None,
states: Optional[List[NodeState]] = None, states: Optional[List[NodeState]] = None,
pool_name: Optional[str] = None, pool_name: Optional[PoolName] = None,
exclude_update_scheduled: bool = False, exclude_update_scheduled: bool = False,
num_results: Optional[int] = None, num_results: Optional[int] = None,
) -> List["Node"]: ) -> List["Node"]:

View File

@ -12,6 +12,7 @@ from subprocess import PIPE, CalledProcessError, check_call # nosec
from typing import List, Optional from typing import List, Optional
from onefuzztypes.models import NotificationConfig from onefuzztypes.models import NotificationConfig
from onefuzztypes.primitives import PoolName
from onefuzz.api import Command, Onefuzz from onefuzz.api import Command, Onefuzz
from onefuzz.cli import execute_api from onefuzz.cli import execute_api
@ -42,7 +43,7 @@ class Ossfuzz(Command):
self, self,
project: str, project: str,
build: str, build: str,
pool: str, pool: PoolName,
sanitizers: Optional[List[str]] = None, sanitizers: Optional[List[str]] = None,
notification_config: Optional[NotificationConfig] = None, notification_config: Optional[NotificationConfig] = None,
) -> None: ) -> None:

View File

@ -170,23 +170,23 @@ class Files(Endpoint):
endpoint = "files" endpoint = "files"
@cached(ttl=ONE_HOUR_IN_SECONDS) @cached(ttl=ONE_HOUR_IN_SECONDS)
def _get_client(self, container: str) -> ContainerWrapper: def _get_client(self, container: primitives.Container) -> ContainerWrapper:
sas = self.onefuzz.containers.get(container).sas_url sas = self.onefuzz.containers.get(container).sas_url
return ContainerWrapper(sas) return ContainerWrapper(sas)
def list(self, container: str) -> models.Files: def list(self, container: primitives.Container) -> models.Files:
""" Get a list of files in a container """ """ Get a list of files in a container """
self.logger.debug("listing files in container: %s", container) self.logger.debug("listing files in container: %s", container)
client = self._get_client(container) client = self._get_client(container)
return models.Files(files=client.list_blobs()) return models.Files(files=client.list_blobs())
def delete(self, container: str, filename: str) -> None: def delete(self, container: primitives.Container, filename: str) -> None:
""" delete a file from a container """ """ delete a file from a container """
self.logger.debug("deleting in container: %s:%s", container, filename) self.logger.debug("deleting in container: %s:%s", container, filename)
client = self._get_client(container) client = self._get_client(container)
client.delete_blob(filename) client.delete_blob(filename)
def get(self, container: str, filename: str) -> bytes: def get(self, container: primitives.Container, filename: str) -> bytes:
""" get a file from a container """ """ get a file from a container """
self.logger.debug("getting file from container: %s:%s", container, filename) self.logger.debug("getting file from container: %s:%s", container, filename)
client = self._get_client(container) client = self._get_client(container)
@ -194,7 +194,10 @@ class Files(Endpoint):
return downloaded return downloaded
def upload_file( def upload_file(
self, container: str, file_path: str, blob_name: Optional[str] = None self,
container: primitives.Container,
file_path: str,
blob_name: Optional[str] = None,
) -> None: ) -> None:
""" uploads a file to a container """ """ uploads a file to a container """
if not blob_name: if not blob_name:
@ -212,7 +215,7 @@ class Files(Endpoint):
client = self._get_client(container) client = self._get_client(container)
client.upload_file(file_path, blob_name) client.upload_file(file_path, blob_name)
def upload_dir(self, container: str, dir_path: str) -> None: def upload_dir(self, container: primitives.Container, dir_path: str) -> None:
""" uploads a directory to a container """ """ uploads a directory to a container """
self.logger.debug("uploading directory to container %s:%s", container, dir_path) self.logger.debug("uploading directory to container %s:%s", container, dir_path)
@ -476,7 +479,9 @@ class Repro(Endpoint):
"GET", models.Repro, data=requests.ReproGet(vm_id=vm_id_expanded) "GET", models.Repro, data=requests.ReproGet(vm_id=vm_id_expanded)
) )
def create(self, container: str, path: str, duration: int = 24) -> models.Repro: def create(
self, container: primitives.Container, path: str, duration: int = 24
) -> models.Repro:
""" Create a Reproduction VM from a Crash Report """ """ Create a Reproduction VM from a Crash Report """
self.logger.info( self.logger.info(
"creating repro vm: %s %s (%d hours)", container, path, duration "creating repro vm: %s %s (%d hours)", container, path, duration
@ -651,7 +656,7 @@ class Repro(Endpoint):
def create_and_connect( def create_and_connect(
self, self,
container: str, container: primitives.Container,
path: str, path: str,
duration: int = 24, duration: int = 24,
delete_after_use: bool = False, delete_after_use: bool = False,
@ -670,14 +675,16 @@ class Notifications(Endpoint):
endpoint = "notifications" endpoint = "notifications"
def create( def create(
self, container: str, config: models.NotificationConfig self, container: primitives.Container, config: models.NotificationConfig
) -> models.Notification: ) -> models.Notification:
""" Create a notification based on a config file """ """ Create a notification based on a config file """
config = requests.NotificationCreate(container=container, config=config.config) config = requests.NotificationCreate(container=container, config=config.config)
return self._req_model("POST", models.Notification, data=config) return self._req_model("POST", models.Notification, data=config)
def create_teams(self, container: str, url: str) -> models.Notification: def create_teams(
self, container: primitives.Container, url: str
) -> models.Notification:
""" Create a Teams notification integration """ """ Create a Teams notification integration """
self.logger.debug("create teams notification integration: %s", container) self.logger.debug("create teams notification integration: %s", container)
@ -687,7 +694,7 @@ class Notifications(Endpoint):
def create_ado( def create_ado(
self, self,
container: str, container: primitives.Container,
project: str, project: str,
base_url: str, base_url: str,
auth_token: str, auth_token: str,
@ -804,7 +811,7 @@ class Tasks(Endpoint):
ensemble_sync_delay: Optional[int] = None, ensemble_sync_delay: Optional[int] = None,
generator_exe: Optional[str] = None, generator_exe: Optional[str] = None,
generator_options: Optional[List[str]] = None, generator_options: Optional[List[str]] = None,
pool_name: str, pool_name: primitives.PoolName,
prereq_tasks: Optional[List[UUID]] = None, prereq_tasks: Optional[List[UUID]] = None,
reboot_after_setup: bool = False, reboot_after_setup: bool = False,
rename_output: bool = False, rename_output: bool = False,
@ -1049,7 +1056,7 @@ class Pool(Endpoint):
), ),
) )
def get_config(self, pool_name: str) -> models.AgentConfig: def get_config(self, pool_name: primitives.PoolName) -> models.AgentConfig:
""" Get the agent configuration for the pool """ """ Get the agent configuration for the pool """
pool = self.get(pool_name) pool = self.get(pool_name)
@ -1168,17 +1175,19 @@ class Node(Endpoint):
*, *,
state: Optional[List[enums.NodeState]] = None, state: Optional[List[enums.NodeState]] = None,
scaleset_id: Optional[UUID_EXPANSION] = None, scaleset_id: Optional[UUID_EXPANSION] = None,
pool_name: Optional[str] = None, pool_name: Optional[primitives.PoolName] = None,
) -> List[models.Node]: ) -> List[models.Node]:
self.logger.debug("list nodes") self.logger.debug("list nodes")
scaleset_id_expanded: Optional[UUID] = None scaleset_id_expanded: Optional[UUID] = None
if pool_name is not None: if pool_name is not None:
pool_name = self._disambiguate( pool_name = primitives.PoolName(
"name", self._disambiguate(
pool_name, "name",
lambda x: False, str(pool_name),
lambda: [x.name for x in self.onefuzz.pools.list()], lambda x: False,
lambda: [x.name for x in self.onefuzz.pools.list()],
)
) )
if scaleset_id is not None: if scaleset_id is not None:
@ -1242,12 +1251,12 @@ class Scaleset(Endpoint):
def create( def create(
self, self,
pool_name: str, pool_name: primitives.PoolName,
size: int, size: int,
*, *,
image: Optional[str] = None, image: Optional[str] = None,
vm_sku: Optional[str] = "Standard_D2s_v3", vm_sku: Optional[str] = "Standard_D2s_v3",
region: Optional[str] = None, region: Optional[primitives.Region] = None,
spot_instances: bool = False, spot_instances: bool = False,
tags: Optional[Dict[str, str]] = None, tags: Optional[Dict[str, str]] = None,
) -> models.Scaleset: ) -> models.Scaleset:
@ -1375,7 +1384,7 @@ class ScalesetProxy(Endpoint):
), ),
) )
def reset(self, region: str) -> responses.BoolResult: def reset(self, region: primitives.Region) -> responses.BoolResult:
""" Reset the proxy for an existing region """ """ Reset the proxy for an existing region """
return self._req_model( return self._req_model(

View File

@ -32,7 +32,7 @@ from uuid import UUID
import jmespath import jmespath
from docstring_parser import parse as parse_docstring from docstring_parser import parse as parse_docstring
from msrest.serialization import Model from msrest.serialization import Model
from onefuzztypes.primitives import Container, Directory, File from onefuzztypes.primitives import Container, Directory, File, PoolName, Region
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
LOGGER = logging.getLogger("cli") LOGGER = logging.getLogger("cli")
@ -158,6 +158,8 @@ class Builder:
int: {"type": int}, int: {"type": int},
UUID: {"type": UUID}, UUID: {"type": UUID},
Container: {"type": str}, Container: {"type": str},
Region: {"type": str},
PoolName: {"type": str},
File: {"type": arg_file}, File: {"type": arg_file},
Directory: {"type": arg_dir}, Directory: {"type": arg_dir},
} }

View File

@ -20,7 +20,7 @@ from azure.applicationinsights.models import QueryBody
from azure.common.client_factory import get_azure_cli_credentials from azure.common.client_factory import get_azure_cli_credentials
from onefuzztypes.enums import ContainerType, TaskType from onefuzztypes.enums import ContainerType, TaskType
from onefuzztypes.models import BlobRef, NodeAssignment, Report, Task from onefuzztypes.models import BlobRef, NodeAssignment, Report, Task
from onefuzztypes.primitives import Directory from onefuzztypes.primitives import Container, Directory
from onefuzz.api import UUID_EXPANSION, Command, Onefuzz from onefuzz.api import UUID_EXPANSION, Command, Onefuzz
@ -583,13 +583,13 @@ class DebugNotification(Command):
def _get_container( def _get_container(
self, task: Task, container_type: ContainerType self, task: Task, container_type: ContainerType
) -> Optional[str]: ) -> Optional[Container]:
for container in task.config.containers: for container in task.config.containers:
if container.type == container_type: if container.type == container_type:
return container.name return container.name
return None return None
def _get_storage_account(self, container_name: str) -> str: def _get_storage_account(self, container_name: Container) -> str:
sas_url = self.onefuzz.containers.get(container_name).sas_url sas_url = self.onefuzz.containers.get(container_name).sas_url
_, netloc, _, _, _, _ = urlparse(sas_url) _, netloc, _, _, _, _ = urlparse(sas_url)
return netloc.split(".")[0] return netloc.split(".")[0]

View File

@ -36,7 +36,7 @@ from onefuzztypes.models import (
TaskContainers, TaskContainers,
UserInfo, UserInfo,
) )
from onefuzztypes.primitives import Container from onefuzztypes.primitives import Container, PoolName
from pydantic import BaseModel from pydantic import BaseModel
MESSAGE = Tuple[datetime, EventType, str] MESSAGE = Tuple[datetime, EventType, str]
@ -49,7 +49,7 @@ DAYS = 24 * HOURS
# status-top only representation of a Node # status-top only representation of a Node
class MiniNode(BaseModel): class MiniNode(BaseModel):
machine_id: UUID machine_id: UUID
pool_name: str pool_name: PoolName
state: NodeState state: NodeState

View File

@ -71,7 +71,7 @@ class JobHelper:
self.project = project self.project = project
self.name = name self.name = name
self.build = build self.build = build
self.to_monitor: Dict[str, int] = {} self.to_monitor: Dict[Container, int] = {}
if platform is None: if platform is None:
self.platform = JobHelper.get_platform(target_exe) self.platform = JobHelper.get_platform(target_exe)

View File

@ -7,7 +7,7 @@ from typing import Dict, List, Optional
from onefuzztypes.enums import OS, ContainerType, StatsFormat, TaskDebugFlag, TaskType from onefuzztypes.enums import OS, ContainerType, StatsFormat, TaskDebugFlag, TaskType
from onefuzztypes.models import Job, NotificationConfig from onefuzztypes.models import Job, NotificationConfig
from onefuzztypes.primitives import Container, Directory, File from onefuzztypes.primitives import Container, Directory, File, PoolName
from onefuzz.api import Command from onefuzz.api import Command
@ -23,7 +23,7 @@ class AFL(Command):
name: str, name: str,
build: str, build: str,
*, *,
pool_name: str, pool_name: PoolName,
target_exe: File = File("fuzz.exe"), target_exe: File = File("fuzz.exe"),
setup_dir: Optional[Directory] = None, setup_dir: Optional[Directory] = None,
vm_count: int = 2, vm_count: int = 2,

View File

@ -7,7 +7,7 @@ from typing import Dict, List, Optional
from onefuzztypes.enums import ContainerType, TaskDebugFlag, TaskType from onefuzztypes.enums import ContainerType, TaskDebugFlag, TaskType
from onefuzztypes.models import Job, NotificationConfig from onefuzztypes.models import Job, NotificationConfig
from onefuzztypes.primitives import Container, Directory, File from onefuzztypes.primitives import Container, Directory, File, PoolName
from onefuzz.api import Command from onefuzz.api import Command
@ -35,7 +35,7 @@ class Libfuzzer(Command):
*, *,
job: Job, job: Job,
containers: Dict[ContainerType, Container], containers: Dict[ContainerType, Container],
pool_name: str, pool_name: PoolName,
target_exe: str, target_exe: str,
vm_count: int = 2, vm_count: int = 2,
reboot_after_setup: bool = False, reboot_after_setup: bool = False,
@ -145,7 +145,7 @@ class Libfuzzer(Command):
project: str, project: str,
name: str, name: str,
build: str, build: str,
pool_name: str, pool_name: PoolName,
*, *,
target_exe: File = File("fuzz.exe"), target_exe: File = File("fuzz.exe"),
setup_dir: Optional[Directory] = None, setup_dir: Optional[Directory] = None,
@ -261,7 +261,7 @@ class Libfuzzer(Command):
project: str, project: str,
name: str, name: str,
build: str, build: str,
pool_name: str, pool_name: PoolName,
*, *,
target_exe: File = File("fuzz.exe"), target_exe: File = File("fuzz.exe"),
setup_dir: Optional[Directory] = None, setup_dir: Optional[Directory] = None,

View File

@ -11,7 +11,7 @@ from typing import Dict, List, Optional, Tuple
from onefuzztypes.enums import OS, ContainerType, TaskDebugFlag from onefuzztypes.enums import OS, ContainerType, TaskDebugFlag
from onefuzztypes.models import NotificationConfig from onefuzztypes.models import NotificationConfig
from onefuzztypes.primitives import File from onefuzztypes.primitives import File, PoolName
from onefuzz.api import Command from onefuzz.api import Command
from onefuzz.backend import container_file_path from onefuzz.backend import container_file_path
@ -110,7 +110,7 @@ class OssFuzz(Command):
self, self,
project: str, project: str,
build: str, build: str,
pool_name: str, pool_name: PoolName,
duration: int = 24, duration: int = 24,
tags: Optional[Dict[str, str]] = None, tags: Optional[Dict[str, str]] = None,
dryrun: bool = False, dryrun: bool = False,

View File

@ -7,7 +7,7 @@ from typing import Dict, List, Optional
from onefuzztypes.enums import OS, ContainerType, TaskDebugFlag, TaskType from onefuzztypes.enums import OS, ContainerType, TaskDebugFlag, TaskType
from onefuzztypes.models import Job, NotificationConfig from onefuzztypes.models import Job, NotificationConfig
from onefuzztypes.primitives import Container, Directory, File from onefuzztypes.primitives import Container, Directory, File, PoolName
from onefuzz.api import Command from onefuzz.api import Command
@ -23,7 +23,7 @@ class Radamsa(Command):
name: str, name: str,
build: str, build: str,
*, *,
pool_name: str, pool_name: PoolName,
target_exe: File = File("fuzz.exe"), target_exe: File = File("fuzz.exe"),
setup_dir: Optional[Directory] = None, setup_dir: Optional[Directory] = None,
vm_count: int = 2, vm_count: int = 2,

View File

@ -12,7 +12,7 @@ from pydantic import BaseModel, Extra, Field
from .enums import OS, Architecture, NodeState, TaskState from .enums import OS, Architecture, NodeState, TaskState
from .models import AutoScaleConfig, Error, JobConfig, Report, TaskConfig, UserInfo from .models import AutoScaleConfig, Error, JobConfig, Report, TaskConfig, UserInfo
from .primitives import Container, Region from .primitives import Container, PoolName, Region
from .responses import BaseResponse from .responses import BaseResponse
@ -66,7 +66,7 @@ class EventPing(BaseResponse):
class EventScalesetCreated(BaseEvent): class EventScalesetCreated(BaseEvent):
scaleset_id: UUID scaleset_id: UUID
pool_name: str pool_name: PoolName
vm_sku: str vm_sku: str
image: str image: str
region: Region region: Region
@ -75,21 +75,21 @@ class EventScalesetCreated(BaseEvent):
class EventScalesetFailed(BaseEvent): class EventScalesetFailed(BaseEvent):
scaleset_id: UUID scaleset_id: UUID
pool_name: str pool_name: PoolName
error: Error error: Error
class EventScalesetDeleted(BaseEvent): class EventScalesetDeleted(BaseEvent):
scaleset_id: UUID scaleset_id: UUID
pool_name: str pool_name: PoolName
class EventPoolDeleted(BaseEvent): class EventPoolDeleted(BaseEvent):
pool_name: str pool_name: PoolName
class EventPoolCreated(BaseEvent): class EventPoolCreated(BaseEvent):
pool_name: str pool_name: PoolName
os: OS os: OS
arch: Architecture arch: Architecture
managed: bool managed: bool
@ -112,19 +112,19 @@ class EventProxyFailed(BaseEvent):
class EventNodeCreated(BaseEvent): class EventNodeCreated(BaseEvent):
machine_id: UUID machine_id: UUID
scaleset_id: Optional[UUID] scaleset_id: Optional[UUID]
pool_name: str pool_name: PoolName
class EventNodeDeleted(BaseEvent): class EventNodeDeleted(BaseEvent):
machine_id: UUID machine_id: UUID
scaleset_id: Optional[UUID] scaleset_id: Optional[UUID]
pool_name: str pool_name: PoolName
class EventNodeStateUpdated(BaseEvent): class EventNodeStateUpdated(BaseEvent):
machine_id: UUID machine_id: UUID
scaleset_id: Optional[UUID] scaleset_id: Optional[UUID]
pool_name: str pool_name: PoolName
state: NodeState state: NodeState

View File

@ -334,7 +334,7 @@ class ClientCredentials(BaseModel):
class AgentConfig(BaseModel): class AgentConfig(BaseModel):
client_credentials: Optional[ClientCredentials] client_credentials: Optional[ClientCredentials]
onefuzz_url: str onefuzz_url: str
pool_name: str pool_name: PoolName
heartbeat_queue: Optional[str] heartbeat_queue: Optional[str]
instrumentation_key: Optional[str] instrumentation_key: Optional[str]
telemetry_key: Optional[str] telemetry_key: Optional[str]

View File

@ -135,7 +135,7 @@ class NodeSearch(BaseRequest):
machine_id: Optional[UUID] machine_id: Optional[UUID]
state: Optional[List[NodeState]] state: Optional[List[NodeState]]
scaleset_id: Optional[UUID] scaleset_id: Optional[UUID]
pool_name: Optional[str] pool_name: Optional[PoolName]
class NodeGet(BaseRequest): class NodeGet(BaseRequest):