mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-17 20:38:06 +00:00
use the primitive types in more places (#514)
This commit is contained in:
@ -170,23 +170,23 @@ class Files(Endpoint):
|
||||
endpoint = "files"
|
||||
|
||||
@cached(ttl=ONE_HOUR_IN_SECONDS)
|
||||
def _get_client(self, container: str) -> ContainerWrapper:
|
||||
def _get_client(self, container: primitives.Container) -> ContainerWrapper:
|
||||
sas = self.onefuzz.containers.get(container).sas_url
|
||||
return ContainerWrapper(sas)
|
||||
|
||||
def list(self, container: str) -> models.Files:
|
||||
def list(self, container: primitives.Container) -> models.Files:
|
||||
""" Get a list of files in a container """
|
||||
self.logger.debug("listing files in container: %s", container)
|
||||
client = self._get_client(container)
|
||||
return models.Files(files=client.list_blobs())
|
||||
|
||||
def delete(self, container: str, filename: str) -> None:
|
||||
def delete(self, container: primitives.Container, filename: str) -> None:
|
||||
""" delete a file from a container """
|
||||
self.logger.debug("deleting in container: %s:%s", container, filename)
|
||||
client = self._get_client(container)
|
||||
client.delete_blob(filename)
|
||||
|
||||
def get(self, container: str, filename: str) -> bytes:
|
||||
def get(self, container: primitives.Container, filename: str) -> bytes:
|
||||
""" get a file from a container """
|
||||
self.logger.debug("getting file from container: %s:%s", container, filename)
|
||||
client = self._get_client(container)
|
||||
@ -194,7 +194,10 @@ class Files(Endpoint):
|
||||
return downloaded
|
||||
|
||||
def upload_file(
|
||||
self, container: str, file_path: str, blob_name: Optional[str] = None
|
||||
self,
|
||||
container: primitives.Container,
|
||||
file_path: str,
|
||||
blob_name: Optional[str] = None,
|
||||
) -> None:
|
||||
""" uploads a file to a container """
|
||||
if not blob_name:
|
||||
@ -212,7 +215,7 @@ class Files(Endpoint):
|
||||
client = self._get_client(container)
|
||||
client.upload_file(file_path, blob_name)
|
||||
|
||||
def upload_dir(self, container: str, dir_path: str) -> None:
|
||||
def upload_dir(self, container: primitives.Container, dir_path: str) -> None:
|
||||
""" uploads a directory to a container """
|
||||
|
||||
self.logger.debug("uploading directory to container %s:%s", container, dir_path)
|
||||
@ -476,7 +479,9 @@ class Repro(Endpoint):
|
||||
"GET", models.Repro, data=requests.ReproGet(vm_id=vm_id_expanded)
|
||||
)
|
||||
|
||||
def create(self, container: str, path: str, duration: int = 24) -> models.Repro:
|
||||
def create(
|
||||
self, container: primitives.Container, path: str, duration: int = 24
|
||||
) -> models.Repro:
|
||||
""" Create a Reproduction VM from a Crash Report """
|
||||
self.logger.info(
|
||||
"creating repro vm: %s %s (%d hours)", container, path, duration
|
||||
@ -651,7 +656,7 @@ class Repro(Endpoint):
|
||||
|
||||
def create_and_connect(
|
||||
self,
|
||||
container: str,
|
||||
container: primitives.Container,
|
||||
path: str,
|
||||
duration: int = 24,
|
||||
delete_after_use: bool = False,
|
||||
@ -670,14 +675,16 @@ class Notifications(Endpoint):
|
||||
endpoint = "notifications"
|
||||
|
||||
def create(
|
||||
self, container: str, config: models.NotificationConfig
|
||||
self, container: primitives.Container, config: models.NotificationConfig
|
||||
) -> models.Notification:
|
||||
""" Create a notification based on a config file """
|
||||
|
||||
config = requests.NotificationCreate(container=container, config=config.config)
|
||||
return self._req_model("POST", models.Notification, data=config)
|
||||
|
||||
def create_teams(self, container: str, url: str) -> models.Notification:
|
||||
def create_teams(
|
||||
self, container: primitives.Container, url: str
|
||||
) -> models.Notification:
|
||||
""" Create a Teams notification integration """
|
||||
|
||||
self.logger.debug("create teams notification integration: %s", container)
|
||||
@ -687,7 +694,7 @@ class Notifications(Endpoint):
|
||||
|
||||
def create_ado(
|
||||
self,
|
||||
container: str,
|
||||
container: primitives.Container,
|
||||
project: str,
|
||||
base_url: str,
|
||||
auth_token: str,
|
||||
@ -804,7 +811,7 @@ class Tasks(Endpoint):
|
||||
ensemble_sync_delay: Optional[int] = None,
|
||||
generator_exe: Optional[str] = None,
|
||||
generator_options: Optional[List[str]] = None,
|
||||
pool_name: str,
|
||||
pool_name: primitives.PoolName,
|
||||
prereq_tasks: Optional[List[UUID]] = None,
|
||||
reboot_after_setup: bool = False,
|
||||
rename_output: bool = False,
|
||||
@ -1049,7 +1056,7 @@ class Pool(Endpoint):
|
||||
),
|
||||
)
|
||||
|
||||
def get_config(self, pool_name: str) -> models.AgentConfig:
|
||||
def get_config(self, pool_name: primitives.PoolName) -> models.AgentConfig:
|
||||
""" Get the agent configuration for the pool """
|
||||
|
||||
pool = self.get(pool_name)
|
||||
@ -1168,17 +1175,19 @@ class Node(Endpoint):
|
||||
*,
|
||||
state: Optional[List[enums.NodeState]] = None,
|
||||
scaleset_id: Optional[UUID_EXPANSION] = None,
|
||||
pool_name: Optional[str] = None,
|
||||
pool_name: Optional[primitives.PoolName] = None,
|
||||
) -> List[models.Node]:
|
||||
self.logger.debug("list nodes")
|
||||
scaleset_id_expanded: Optional[UUID] = None
|
||||
|
||||
if pool_name is not None:
|
||||
pool_name = self._disambiguate(
|
||||
"name",
|
||||
pool_name,
|
||||
lambda x: False,
|
||||
lambda: [x.name for x in self.onefuzz.pools.list()],
|
||||
pool_name = primitives.PoolName(
|
||||
self._disambiguate(
|
||||
"name",
|
||||
str(pool_name),
|
||||
lambda x: False,
|
||||
lambda: [x.name for x in self.onefuzz.pools.list()],
|
||||
)
|
||||
)
|
||||
|
||||
if scaleset_id is not None:
|
||||
@ -1242,12 +1251,12 @@ class Scaleset(Endpoint):
|
||||
|
||||
def create(
|
||||
self,
|
||||
pool_name: str,
|
||||
pool_name: primitives.PoolName,
|
||||
size: int,
|
||||
*,
|
||||
image: Optional[str] = None,
|
||||
vm_sku: Optional[str] = "Standard_D2s_v3",
|
||||
region: Optional[str] = None,
|
||||
region: Optional[primitives.Region] = None,
|
||||
spot_instances: bool = False,
|
||||
tags: Optional[Dict[str, str]] = None,
|
||||
) -> models.Scaleset:
|
||||
@ -1375,7 +1384,7 @@ class ScalesetProxy(Endpoint):
|
||||
),
|
||||
)
|
||||
|
||||
def reset(self, region: str) -> responses.BoolResult:
|
||||
def reset(self, region: primitives.Region) -> responses.BoolResult:
|
||||
""" Reset the proxy for an existing region """
|
||||
|
||||
return self._req_model(
|
||||
|
Reference in New Issue
Block a user