use the primitive types in more places (#514)

This commit is contained in:
bmc-msft
2021-02-05 13:10:37 -05:00
committed by GitHub
parent 51f4eea069
commit 1d74379a70
17 changed files with 74 additions and 61 deletions

View File

@ -170,23 +170,23 @@ class Files(Endpoint):
endpoint = "files"
@cached(ttl=ONE_HOUR_IN_SECONDS)
def _get_client(self, container: str) -> ContainerWrapper:
def _get_client(self, container: primitives.Container) -> ContainerWrapper:
sas = self.onefuzz.containers.get(container).sas_url
return ContainerWrapper(sas)
def list(self, container: str) -> models.Files:
def list(self, container: primitives.Container) -> models.Files:
""" Get a list of files in a container """
self.logger.debug("listing files in container: %s", container)
client = self._get_client(container)
return models.Files(files=client.list_blobs())
def delete(self, container: str, filename: str) -> None:
def delete(self, container: primitives.Container, filename: str) -> None:
""" delete a file from a container """
self.logger.debug("deleting in container: %s:%s", container, filename)
client = self._get_client(container)
client.delete_blob(filename)
def get(self, container: str, filename: str) -> bytes:
def get(self, container: primitives.Container, filename: str) -> bytes:
""" get a file from a container """
self.logger.debug("getting file from container: %s:%s", container, filename)
client = self._get_client(container)
@ -194,7 +194,10 @@ class Files(Endpoint):
return downloaded
def upload_file(
self, container: str, file_path: str, blob_name: Optional[str] = None
self,
container: primitives.Container,
file_path: str,
blob_name: Optional[str] = None,
) -> None:
""" uploads a file to a container """
if not blob_name:
@ -212,7 +215,7 @@ class Files(Endpoint):
client = self._get_client(container)
client.upload_file(file_path, blob_name)
def upload_dir(self, container: str, dir_path: str) -> None:
def upload_dir(self, container: primitives.Container, dir_path: str) -> None:
""" uploads a directory to a container """
self.logger.debug("uploading directory to container %s:%s", container, dir_path)
@ -476,7 +479,9 @@ class Repro(Endpoint):
"GET", models.Repro, data=requests.ReproGet(vm_id=vm_id_expanded)
)
def create(self, container: str, path: str, duration: int = 24) -> models.Repro:
def create(
self, container: primitives.Container, path: str, duration: int = 24
) -> models.Repro:
""" Create a Reproduction VM from a Crash Report """
self.logger.info(
"creating repro vm: %s %s (%d hours)", container, path, duration
@ -651,7 +656,7 @@ class Repro(Endpoint):
def create_and_connect(
self,
container: str,
container: primitives.Container,
path: str,
duration: int = 24,
delete_after_use: bool = False,
@ -670,14 +675,16 @@ class Notifications(Endpoint):
endpoint = "notifications"
def create(
self, container: str, config: models.NotificationConfig
self, container: primitives.Container, config: models.NotificationConfig
) -> models.Notification:
""" Create a notification based on a config file """
config = requests.NotificationCreate(container=container, config=config.config)
return self._req_model("POST", models.Notification, data=config)
def create_teams(self, container: str, url: str) -> models.Notification:
def create_teams(
self, container: primitives.Container, url: str
) -> models.Notification:
""" Create a Teams notification integration """
self.logger.debug("create teams notification integration: %s", container)
@ -687,7 +694,7 @@ class Notifications(Endpoint):
def create_ado(
self,
container: str,
container: primitives.Container,
project: str,
base_url: str,
auth_token: str,
@ -804,7 +811,7 @@ class Tasks(Endpoint):
ensemble_sync_delay: Optional[int] = None,
generator_exe: Optional[str] = None,
generator_options: Optional[List[str]] = None,
pool_name: str,
pool_name: primitives.PoolName,
prereq_tasks: Optional[List[UUID]] = None,
reboot_after_setup: bool = False,
rename_output: bool = False,
@ -1049,7 +1056,7 @@ class Pool(Endpoint):
),
)
def get_config(self, pool_name: str) -> models.AgentConfig:
def get_config(self, pool_name: primitives.PoolName) -> models.AgentConfig:
""" Get the agent configuration for the pool """
pool = self.get(pool_name)
@ -1168,17 +1175,19 @@ class Node(Endpoint):
*,
state: Optional[List[enums.NodeState]] = None,
scaleset_id: Optional[UUID_EXPANSION] = None,
pool_name: Optional[str] = None,
pool_name: Optional[primitives.PoolName] = None,
) -> List[models.Node]:
self.logger.debug("list nodes")
scaleset_id_expanded: Optional[UUID] = None
if pool_name is not None:
pool_name = self._disambiguate(
"name",
pool_name,
lambda x: False,
lambda: [x.name for x in self.onefuzz.pools.list()],
pool_name = primitives.PoolName(
self._disambiguate(
"name",
str(pool_name),
lambda x: False,
lambda: [x.name for x in self.onefuzz.pools.list()],
)
)
if scaleset_id is not None:
@ -1242,12 +1251,12 @@ class Scaleset(Endpoint):
def create(
self,
pool_name: str,
pool_name: primitives.PoolName,
size: int,
*,
image: Optional[str] = None,
vm_sku: Optional[str] = "Standard_D2s_v3",
region: Optional[str] = None,
region: Optional[primitives.Region] = None,
spot_instances: bool = False,
tags: Optional[Dict[str, str]] = None,
) -> models.Scaleset:
@ -1375,7 +1384,7 @@ class ScalesetProxy(Endpoint):
),
)
def reset(self, region: str) -> responses.BoolResult:
def reset(self, region: primitives.Region) -> responses.BoolResult:
""" Reset the proxy for an existing region """
return self._req_model(