use the primitive types in more places (#514)

Author: bmc-msft
Date: 2021-02-05 13:10:37 -05:00
Committed by: GitHub
Parent: 51f4eea069
Commit: 1d74379a70
17 changed files with 74 additions and 61 deletions
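
The commit swaps bare str annotations for the NewType aliases defined in onefuzztypes.primitives. As a rough sketch (an assumption about that module, not a copy of it), the aliases look roughly like this:

    # Sketch only: the real definitions live in onefuzztypes.primitives and may
    # include more aliases or helpers than shown here.
    from typing import NewType

    Region = NewType("Region", str)
    Container = NewType("Container", str)
    PoolName = NewType("PoolName", str)

    # NewType is erased at runtime: Region("eastus") is the plain str "eastus".
    # The benefit is static: mypy rejects passing an arbitrary str where a
    # Region, Container, or PoolName is expected.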

View File

@@ -16,7 +16,7 @@ from azure.mgmt.subscription import SubscriptionClient
from memoization import cached
from msrestazure.azure_active_directory import MSIAuthentication
from msrestazure.tools import parse_resource_id
from onefuzztypes.primitives import Container
from onefuzztypes.primitives import Container, Region
from .monkeypatch import allow_more_workers, reduce_logging
@@ -41,12 +41,12 @@ def get_base_resource_group() -> Any: # should be str
@cached
def get_base_region() -> Any: # should be str
def get_base_region() -> Region:
client = ResourceManagementClient(
credential=get_identity(), subscription_id=get_subscription()
)
group = client.resource_groups.get(get_base_resource_group())
return group.location
return Region(group.location)
@cached
@@ -89,11 +89,11 @@ DAY_IN_SECONDS = 60 * 60 * 24
@cached(ttl=DAY_IN_SECONDS)
def get_regions() -> List[str]:
def get_regions() -> List[Region]:
subscription = get_subscription()
client = SubscriptionClient(credential=get_identity())
locations = client.subscriptions.list_locations(subscription)
return sorted([x.name for x in locations])
return sorted([Region(x.name) for x in locations])
@cached
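
The recurring pattern here is to wrap a value at the point where it enters the system, as Region(group.location) does above, so everything downstream can rely on the stricter annotation. A minimal illustrative sketch, with reset_proxy standing in for any Region-typed consumer:

    from typing import NewType

    Region = NewType("Region", str)  # assumed definition, see onefuzztypes.primitives

    def get_base_region() -> Region:
        location = "eastus"  # stand-in for the resource-group lookup above
        return Region(location)  # wrap once, where the value enters the system

    def reset_proxy(region: Region) -> None:
        print("resetting proxy in", region)

    reset_proxy(get_base_region())  # fine
    reset_proxy("eastus")           # runs, but mypy reports an incompatible str argument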

View File

@@ -49,7 +49,7 @@ def check_container(
compare: Compare,
expected: int,
container_type: ContainerType,
containers: Dict[ContainerType, List[str]],
containers: Dict[ContainerType, List[Container]],
) -> None:
actual = len(containers.get(container_type, []))
if not check_val(compare, expected, actual):
@@ -62,7 +62,7 @@ def check_container(
def check_containers(definition: TaskDefinition, config: TaskConfig) -> None:
checked = set()
containers: Dict[ContainerType, List[str]] = {}
containers: Dict[ContainerType, List[Container]] = {}
for container in config.containers:
if container.name not in checked:
if not container_exists(container.name, StorageType.corpus):
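
Typing the lookup as Dict[ContainerType, List[Container]] records that the values are container names rather than arbitrary strings. A small self-contained sketch of building such a map, where TaskContainer is a hypothetical stand-in for the entries on config.containers (which, as above, expose .type and .name):

    from dataclasses import dataclass
    from enum import Enum
    from typing import Dict, List, NewType

    Container = NewType("Container", str)  # assumed definition

    class ContainerType(Enum):
        setup = "setup"
        crashes = "crashes"

    @dataclass
    class TaskContainer:
        type: ContainerType
        name: Container

    def group_by_type(entries: List[TaskContainer]) -> Dict[ContainerType, List[Container]]:
        grouped: Dict[ContainerType, List[Container]] = {}
        for entry in entries:
            grouped.setdefault(entry.type, []).append(entry.name)
        return grouped

    print(group_by_type([TaskContainer(ContainerType.setup, Container("my-setup"))]))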

View File

@@ -18,6 +18,7 @@ from onefuzztypes.events import (
from onefuzztypes.models import Error
from onefuzztypes.models import Task as BASE_TASK
from onefuzztypes.models import TaskConfig, TaskVm, UserInfo
from onefuzztypes.primitives import PoolName
from ..azure.image import get_os
from ..azure.queue import create_queue, delete_queue
@@ -165,7 +166,7 @@ class Task(BASE_TASK, ORMMixin):
return task
@classmethod
def get_tasks_by_pool_name(cls, pool_name: str) -> List["Task"]:
def get_tasks_by_pool_name(cls, pool_name: PoolName) -> List["Task"]:
tasks = cls.search_states(states=TaskState.available())
if not tasks:
return []

View File

@@ -72,7 +72,7 @@ class Node(BASE_NODE, ORMMixin):
*,
scaleset_id: Optional[UUID] = None,
states: Optional[List[NodeState]] = None,
pool_name: Optional[str] = None,
pool_name: Optional[PoolName] = None,
) -> List["Node"]:
query: QueryFilter = {}
if scaleset_id:
@@ -89,7 +89,7 @@ class Node(BASE_NODE, ORMMixin):
*,
scaleset_id: Optional[UUID] = None,
states: Optional[List[NodeState]] = None,
pool_name: Optional[str] = None,
pool_name: Optional[PoolName] = None,
exclude_update_scheduled: bool = False,
num_results: Optional[int] = None,
) -> List["Node"]:

View File

@@ -12,6 +12,7 @@ from subprocess import PIPE, CalledProcessError, check_call # nosec
from typing import List, Optional
from onefuzztypes.models import NotificationConfig
from onefuzztypes.primitives import PoolName
from onefuzz.api import Command, Onefuzz
from onefuzz.cli import execute_api
@@ -42,7 +43,7 @@ class Ossfuzz(Command):
self,
project: str,
build: str,
pool: str,
pool: PoolName,
sanitizers: Optional[List[str]] = None,
notification_config: Optional[NotificationConfig] = None,
) -> None:

View File

@@ -170,23 +170,23 @@ class Files(Endpoint):
endpoint = "files"
@cached(ttl=ONE_HOUR_IN_SECONDS)
def _get_client(self, container: str) -> ContainerWrapper:
def _get_client(self, container: primitives.Container) -> ContainerWrapper:
sas = self.onefuzz.containers.get(container).sas_url
return ContainerWrapper(sas)
def list(self, container: str) -> models.Files:
def list(self, container: primitives.Container) -> models.Files:
""" Get a list of files in a container """
self.logger.debug("listing files in container: %s", container)
client = self._get_client(container)
return models.Files(files=client.list_blobs())
def delete(self, container: str, filename: str) -> None:
def delete(self, container: primitives.Container, filename: str) -> None:
""" delete a file from a container """
self.logger.debug("deleting in container: %s:%s", container, filename)
client = self._get_client(container)
client.delete_blob(filename)
def get(self, container: str, filename: str) -> bytes:
def get(self, container: primitives.Container, filename: str) -> bytes:
""" get a file from a container """
self.logger.debug("getting file from container: %s:%s", container, filename)
client = self._get_client(container)
@@ -194,7 +194,10 @@ class Files(Endpoint):
return downloaded
def upload_file(
self, container: str, file_path: str, blob_name: Optional[str] = None
self,
container: primitives.Container,
file_path: str,
blob_name: Optional[str] = None,
) -> None:
""" uploads a file to a container """
if not blob_name:
@@ -212,7 +215,7 @@ class Files(Endpoint):
client = self._get_client(container)
client.upload_file(file_path, blob_name)
def upload_dir(self, container: str, dir_path: str) -> None:
def upload_dir(self, container: primitives.Container, dir_path: str) -> None:
""" uploads a directory to a container """
self.logger.debug("uploading directory to container %s:%s", container, dir_path)
@@ -476,7 +479,9 @@ class Repro(Endpoint):
"GET", models.Repro, data=requests.ReproGet(vm_id=vm_id_expanded)
)
def create(self, container: str, path: str, duration: int = 24) -> models.Repro:
def create(
self, container: primitives.Container, path: str, duration: int = 24
) -> models.Repro:
""" Create a Reproduction VM from a Crash Report """
self.logger.info(
"creating repro vm: %s %s (%d hours)", container, path, duration
@@ -651,7 +656,7 @@ class Repro(Endpoint):
def create_and_connect(
self,
container: str,
container: primitives.Container,
path: str,
duration: int = 24,
delete_after_use: bool = False,
@@ -670,14 +675,16 @@ class Notifications(Endpoint):
endpoint = "notifications"
def create(
self, container: str, config: models.NotificationConfig
self, container: primitives.Container, config: models.NotificationConfig
) -> models.Notification:
""" Create a notification based on a config file """
config = requests.NotificationCreate(container=container, config=config.config)
return self._req_model("POST", models.Notification, data=config)
def create_teams(self, container: str, url: str) -> models.Notification:
def create_teams(
self, container: primitives.Container, url: str
) -> models.Notification:
""" Create a Teams notification integration """
self.logger.debug("create teams notification integration: %s", container)
@@ -687,7 +694,7 @@ class Notifications(Endpoint):
def create_ado(
self,
container: str,
container: primitives.Container,
project: str,
base_url: str,
auth_token: str,
@@ -804,7 +811,7 @@ class Tasks(Endpoint):
ensemble_sync_delay: Optional[int] = None,
generator_exe: Optional[str] = None,
generator_options: Optional[List[str]] = None,
pool_name: str,
pool_name: primitives.PoolName,
prereq_tasks: Optional[List[UUID]] = None,
reboot_after_setup: bool = False,
rename_output: bool = False,
@@ -1049,7 +1056,7 @@ class Pool(Endpoint):
),
)
def get_config(self, pool_name: str) -> models.AgentConfig:
def get_config(self, pool_name: primitives.PoolName) -> models.AgentConfig:
""" Get the agent configuration for the pool """
pool = self.get(pool_name)
@@ -1168,17 +1175,19 @@ class Node(Endpoint):
*,
state: Optional[List[enums.NodeState]] = None,
scaleset_id: Optional[UUID_EXPANSION] = None,
pool_name: Optional[str] = None,
pool_name: Optional[primitives.PoolName] = None,
) -> List[models.Node]:
self.logger.debug("list nodes")
scaleset_id_expanded: Optional[UUID] = None
if pool_name is not None:
pool_name = self._disambiguate(
"name",
pool_name,
lambda x: False,
lambda: [x.name for x in self.onefuzz.pools.list()],
pool_name = primitives.PoolName(
self._disambiguate(
"name",
str(pool_name),
lambda x: False,
lambda: [x.name for x in self.onefuzz.pools.list()],
)
)
if scaleset_id is not None:
@@ -1242,12 +1251,12 @@ class Scaleset(Endpoint):
def create(
self,
pool_name: str,
pool_name: primitives.PoolName,
size: int,
*,
image: Optional[str] = None,
vm_sku: Optional[str] = "Standard_D2s_v3",
region: Optional[str] = None,
region: Optional[primitives.Region] = None,
spot_instances: bool = False,
tags: Optional[Dict[str, str]] = None,
) -> models.Scaleset:
@@ -1375,7 +1384,7 @@ class ScalesetProxy(Endpoint):
),
)
def reset(self, region: str) -> responses.BoolResult:
def reset(self, region: primitives.Region) -> responses.BoolResult:
""" Reset the proxy for an existing region """
return self._req_model(
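
One recurring wrinkle shows up in the Node.list change above: helpers such as _disambiguate operate on plain strings, so their str result has to be wrapped back into the NewType before reuse. A simplified sketch of that round trip, with disambiguate as a hypothetical stand-in for the real helper:

    from typing import Callable, List, NewType, Optional

    PoolName = NewType("PoolName", str)  # assumed definition

    def disambiguate(value: str, candidates: Callable[[], List[str]]) -> str:
        # hypothetical stand-in for Endpoint._disambiguate: expand a prefix to a full name
        matches = [x for x in candidates() if x.startswith(value)]
        return matches[0] if len(matches) == 1 else value

    def list_nodes(pool_name: Optional[PoolName] = None) -> None:
        if pool_name is not None:
            # the string-level helper returns str; wrap the result back into PoolName
            resolved = disambiguate(str(pool_name), lambda: ["linux-pool", "windows-pool"])
            pool_name = PoolName(resolved)
        print("filtering nodes by pool:", pool_name)

    list_nodes(PoolName("lin"))  # prints: filtering nodes by pool: linux-pool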

View File

@@ -32,7 +32,7 @@ from uuid import UUID
import jmespath
from docstring_parser import parse as parse_docstring
from msrest.serialization import Model
from onefuzztypes.primitives import Container, Directory, File
from onefuzztypes.primitives import Container, Directory, File, PoolName, Region
from pydantic import BaseModel, ValidationError
LOGGER = logging.getLogger("cli")
@@ -158,6 +158,8 @@ class Builder:
int: {"type": int},
UUID: {"type": UUID},
Container: {"type": str},
Region: {"type": str},
PoolName: {"type": str},
File: {"type": arg_file},
Directory: {"type": arg_dir},
}
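
On the CLI side the new aliases still parse as ordinary strings; the table above just tells the argument builder which callable to hand to argparse for each annotation. A rough sketch of how such a map can drive argument registration (the real Builder does considerably more):

    import argparse
    from typing import Any, Callable, Dict, NewType
    from uuid import UUID

    Container = NewType("Container", str)  # assumed definitions mirroring
    Region = NewType("Region", str)        # onefuzztypes.primitives
    PoolName = NewType("PoolName", str)

    # annotation -> argparse "type" callable; NewTypes over str fall back to plain str
    TYPES: Dict[Any, Callable[[str], Any]] = {
        str: str,
        int: int,
        UUID: UUID,
        Container: str,
        Region: str,
        PoolName: str,
    }

    parser = argparse.ArgumentParser()
    parser.add_argument("pool_name", type=TYPES[PoolName])
    args = parser.parse_args(["my-pool"])
    print(args.pool_name)  # plain "my-pool" at runtime; PoolName only matters to mypy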

View File

@@ -20,7 +20,7 @@ from azure.applicationinsights.models import QueryBody
from azure.common.client_factory import get_azure_cli_credentials
from onefuzztypes.enums import ContainerType, TaskType
from onefuzztypes.models import BlobRef, NodeAssignment, Report, Task
from onefuzztypes.primitives import Directory
from onefuzztypes.primitives import Container, Directory
from onefuzz.api import UUID_EXPANSION, Command, Onefuzz
@@ -583,13 +583,13 @@ class DebugNotification(Command):
def _get_container(
self, task: Task, container_type: ContainerType
) -> Optional[str]:
) -> Optional[Container]:
for container in task.config.containers:
if container.type == container_type:
return container.name
return None
def _get_storage_account(self, container_name: str) -> str:
def _get_storage_account(self, container_name: Container) -> str:
sas_url = self.onefuzz.containers.get(container_name).sas_url
_, netloc, _, _, _, _ = urlparse(sas_url)
return netloc.split(".")[0]

View File

@@ -36,7 +36,7 @@ from onefuzztypes.models import (
TaskContainers,
UserInfo,
)
from onefuzztypes.primitives import Container
from onefuzztypes.primitives import Container, PoolName
from pydantic import BaseModel
MESSAGE = Tuple[datetime, EventType, str]
@@ -49,7 +49,7 @@ DAYS = 24 * HOURS
# status-top only representation of a Node
class MiniNode(BaseModel):
machine_id: UUID
pool_name: str
pool_name: PoolName
state: NodeState

View File

@@ -71,7 +71,7 @@ class JobHelper:
self.project = project
self.name = name
self.build = build
self.to_monitor: Dict[str, int] = {}
self.to_monitor: Dict[Container, int] = {}
if platform is None:
self.platform = JobHelper.get_platform(target_exe)

View File

@@ -7,7 +7,7 @@ from typing import Dict, List, Optional
from onefuzztypes.enums import OS, ContainerType, StatsFormat, TaskDebugFlag, TaskType
from onefuzztypes.models import Job, NotificationConfig
from onefuzztypes.primitives import Container, Directory, File
from onefuzztypes.primitives import Container, Directory, File, PoolName
from onefuzz.api import Command
@@ -23,7 +23,7 @@ class AFL(Command):
name: str,
build: str,
*,
pool_name: str,
pool_name: PoolName,
target_exe: File = File("fuzz.exe"),
setup_dir: Optional[Directory] = None,
vm_count: int = 2,

View File

@@ -7,7 +7,7 @@ from typing import Dict, List, Optional
from onefuzztypes.enums import ContainerType, TaskDebugFlag, TaskType
from onefuzztypes.models import Job, NotificationConfig
from onefuzztypes.primitives import Container, Directory, File
from onefuzztypes.primitives import Container, Directory, File, PoolName
from onefuzz.api import Command
@@ -35,7 +35,7 @@ class Libfuzzer(Command):
*,
job: Job,
containers: Dict[ContainerType, Container],
pool_name: str,
pool_name: PoolName,
target_exe: str,
vm_count: int = 2,
reboot_after_setup: bool = False,
@@ -145,7 +145,7 @@ class Libfuzzer(Command):
project: str,
name: str,
build: str,
pool_name: str,
pool_name: PoolName,
*,
target_exe: File = File("fuzz.exe"),
setup_dir: Optional[Directory] = None,
@@ -261,7 +261,7 @@ class Libfuzzer(Command):
project: str,
name: str,
build: str,
pool_name: str,
pool_name: PoolName,
*,
target_exe: File = File("fuzz.exe"),
setup_dir: Optional[Directory] = None,
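
Callers of the templates now pass the primitives explicitly. An illustrative call, not taken from the repo: the client setup and the template attribute path are assumptions, while the argument types follow the signatures shown above.

    # Illustrative only; assumes an already-configured, authenticated client and
    # that the libfuzzer template is reachable via a "template" attribute.
    from onefuzz.api import Onefuzz
    from onefuzztypes.primitives import File, PoolName

    onefuzz = Onefuzz()
    onefuzz.template.libfuzzer.basic(
        "my-project",
        "my-target",
        "build-1",
        PoolName("linux-pool"),
        target_exe=File("fuzz.exe"),
    )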

View File

@@ -11,7 +11,7 @@ from typing import Dict, List, Optional, Tuple
from onefuzztypes.enums import OS, ContainerType, TaskDebugFlag
from onefuzztypes.models import NotificationConfig
from onefuzztypes.primitives import File
from onefuzztypes.primitives import File, PoolName
from onefuzz.api import Command
from onefuzz.backend import container_file_path
@@ -110,7 +110,7 @@ class OssFuzz(Command):
self,
project: str,
build: str,
pool_name: str,
pool_name: PoolName,
duration: int = 24,
tags: Optional[Dict[str, str]] = None,
dryrun: bool = False,

View File

@@ -7,7 +7,7 @@ from typing import Dict, List, Optional
from onefuzztypes.enums import OS, ContainerType, TaskDebugFlag, TaskType
from onefuzztypes.models import Job, NotificationConfig
from onefuzztypes.primitives import Container, Directory, File
from onefuzztypes.primitives import Container, Directory, File, PoolName
from onefuzz.api import Command
@@ -23,7 +23,7 @@ class Radamsa(Command):
name: str,
build: str,
*,
pool_name: str,
pool_name: PoolName,
target_exe: File = File("fuzz.exe"),
setup_dir: Optional[Directory] = None,
vm_count: int = 2,

View File

@@ -12,7 +12,7 @@ from pydantic import BaseModel, Extra, Field
from .enums import OS, Architecture, NodeState, TaskState
from .models import AutoScaleConfig, Error, JobConfig, Report, TaskConfig, UserInfo
from .primitives import Container, Region
from .primitives import Container, PoolName, Region
from .responses import BaseResponse
@@ -66,7 +66,7 @@ class EventPing(BaseResponse):
class EventScalesetCreated(BaseEvent):
scaleset_id: UUID
pool_name: str
pool_name: PoolName
vm_sku: str
image: str
region: Region
@@ -75,21 +75,21 @@ class EventScalesetCreated(BaseEvent):
class EventScalesetFailed(BaseEvent):
scaleset_id: UUID
pool_name: str
pool_name: PoolName
error: Error
class EventScalesetDeleted(BaseEvent):
scaleset_id: UUID
pool_name: str
pool_name: PoolName
class EventPoolDeleted(BaseEvent):
pool_name: str
pool_name: PoolName
class EventPoolCreated(BaseEvent):
pool_name: str
pool_name: PoolName
os: OS
arch: Architecture
managed: bool
@@ -112,19 +112,19 @@ class EventProxyFailed(BaseEvent):
class EventNodeCreated(BaseEvent):
machine_id: UUID
scaleset_id: Optional[UUID]
pool_name: str
pool_name: PoolName
class EventNodeDeleted(BaseEvent):
machine_id: UUID
scaleset_id: Optional[UUID]
pool_name: str
pool_name: PoolName
class EventNodeStateUpdated(BaseEvent):
machine_id: UUID
scaleset_id: Optional[UUID]
pool_name: str
pool_name: PoolName
state: NodeState
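
Because pydantic validates a NewType against its underlying type, changing these event fields from str to PoolName leaves the serialized payload untouched. A minimal standalone sketch, with PoolCreated as a simplified stand-in for the real event model:

    from typing import NewType

    from pydantic import BaseModel

    PoolName = NewType("PoolName", str)  # assumed definition

    class PoolCreated(BaseModel):  # simplified stand-in for the real event class
        pool_name: PoolName

    event = PoolCreated(pool_name=PoolName("linux-pool"))
    print(event.json())  # {"pool_name": "linux-pool"} - same payload as the str version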

View File

@@ -334,7 +334,7 @@ class ClientCredentials(BaseModel):
class AgentConfig(BaseModel):
client_credentials: Optional[ClientCredentials]
onefuzz_url: str
pool_name: str
pool_name: PoolName
heartbeat_queue: Optional[str]
instrumentation_key: Optional[str]
telemetry_key: Optional[str]

View File

@@ -135,7 +135,7 @@ class NodeSearch(BaseRequest):
machine_id: Optional[UUID]
state: Optional[List[NodeState]]
scaleset_id: Optional[UUID]
pool_name: Optional[str]
pool_name: Optional[PoolName]
class NodeGet(BaseRequest):