move SDK to use request models rather than hand-crafted json (#191)

This commit is contained in:
bmc-msft
2020-10-23 08:39:45 -04:00
committed by GitHub
parent d769072343
commit 1c06d7085a

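The diff below converts every hand-built JSON dictionary in the SDK into a typed request model; the new `_req_model` signature types its payload as `Optional[BaseModel]`. As a rough sketch of the pattern (the actual model definitions are not part of this diff; the field shown for ContainerGet is inferred from its call site below):

    # Illustrative only -- assumes the requests.* models are pydantic BaseModels.
    from pydantic import BaseModel

    class ContainerGet(BaseModel):
        name: str

    # Callers now construct a typed object instead of {"name": name}:
    req = ContainerGet(name="my-container")

Validation happens at construction time, so a misspelled or missing field fails in the client rather than at the service.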

@@ -10,7 +10,7 @@ import re
 import subprocess  # nosec
 import uuid
 from shutil import which
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar
+from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar
 from uuid import UUID
 
 import pkg_resources
@@ -66,12 +66,16 @@ class Endpoint:
         method: str,
         model: Type[A],
         *,
-        json_data: Any = None,
-        params: Optional[Dict[str, Any]] = None,
+        data: Optional[BaseModel] = None,
+        as_params: bool = False,
     ) -> A:
-        response = self.onefuzz._backend.request(
-            method, self.endpoint, json_data=json_data, params=params
-        )
+        if as_params:
+            response = self.onefuzz._backend.request(method, self.endpoint, params=data)
+        else:
+            response = self.onefuzz._backend.request(
+                method, self.endpoint, json_data=data
+            )
+
         return model.parse_obj(response)
 
     def _req_model_list(
@@ -79,11 +83,14 @@ class Endpoint:
         method: str,
         model: Type[A],
         *,
-        json_data: Any = None,
-        params: Optional[Dict[str, Any]] = None,
+        data: Optional[BaseModel] = None,
+        as_params: bool = False,
     ) -> List[A]:
-        response = self.onefuzz._backend.request(
-            method, self.endpoint, json_data=json_data, params=params
-        )
+        if as_params:
+            response = self.onefuzz._backend.request(method, self.endpoint, params=data)
+        else:
+            response = self.onefuzz._backend.request(
+                method, self.endpoint, json_data=data
+            )
         return [model.parse_obj(x) for x in response]
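`_req_model` and `_req_model_list` now take a request model plus an `as_params` flag instead of raw `json_data`/`params` arguments. The backend's `request()` method is not part of this diff; presumably it serializes the model itself. A minimal sketch of that assumption (the helper name is hypothetical):

    from typing import Any, Dict, Optional
    from pydantic import BaseModel

    def to_wire(data: Optional[BaseModel]) -> Optional[Dict[str, Any]]:
        # Hypothetical: flatten a request model into the JSON body or query
        # parameters the HTTP layer sends; drop unset fields so GET searches
        # stay sparse.
        if data is None:
            return None
        return data.dict(exclude_none=True)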
@@ -240,7 +247,9 @@ class Containers(Endpoint):
     def get(self, name: str) -> responses.ContainerInfo:
         """ Get a fully qualified SAS URL for a container """
         self.logger.debug("get container: %s", name)
-        return self._req_model("GET", responses.ContainerInfo, json_data={"name": name})
+        return self._req_model(
+            "GET", responses.ContainerInfo, data=requests.ContainerGet(name=name)
+        )
 
     def create(
         self, name: str, metadata: Optional[Dict[str, str]] = None
@@ -250,13 +259,15 @@ class Containers(Endpoint):
         return self._req_model(
             "POST",
             responses.ContainerInfo,
-            json_data={"name": name, "metadata": metadata},
+            data=requests.ContainerCreate(name=name, metadata=metadata),
         )
 
     def delete(self, name: str) -> responses.BoolResult:
         """ Delete a storage container """
         self.logger.debug("delete container: %s", name)
-        return self._req_model("DELETE", responses.BoolResult, json_data={"name": name})
+        return self._req_model(
+            "DELETE", responses.BoolResult, data=requests.ContainerDelete(name=name)
+        )
 
     def list(self) -> List[responses.ContainerInfoBase]:
         """ Get a list of containers """
@@ -276,7 +287,9 @@ class Repro(Endpoint):
         )
         self.logger.debug("get repro vm: %s", vm_id_expanded)
-        return self._req_model("GET", models.Repro, json_data={"vm_id": vm_id_expanded})
+        return self._req_model(
+            "GET", models.Repro, data=requests.ReproGet(vm_id=vm_id_expanded)
+        )
 
     def create(self, container: str, path: str, duration: int = 24) -> models.Repro:
         """ Create a Reproduction VM from a Crash Report """
@@ -286,7 +299,7 @@ class Repro(Endpoint):
         return self._req_model(
             "POST",
             models.Repro,
-            json_data={"container": container, "path": path, "duration": duration},
+            data=models.ReproConfig(container=container, path=path, duration=duration),
         )
 
     def delete(self, vm_id: UUID_EXPANSION) -> models.Repro:
@@ -297,13 +310,13 @@ class Repro(Endpoint):
         self.logger.debug("deleting repro vm: %s", vm_id_expanded)
         return self._req_model(
-            "DELETE", models.Repro, json_data={"vm_id": vm_id_expanded}
+            "DELETE", models.Repro, data=requests.ReproGet(vm_id=vm_id_expanded)
         )
 
     def list(self) -> List[models.Repro]:
         """ List all VMs """
         self.logger.debug("listing repro vms")
-        return self._req_model_list("GET", models.Repro, json_data={})
+        return self._req_model_list("GET", models.Repro, data=requests.ReproGet())
 
     def _dbg_linux(
         self, repro: models.Repro, debug_command: Optional[str]
@@ -469,7 +482,7 @@ class Notifications(Endpoint):
         """ Create a notification based on a config file """
         config = requests.NotificationCreate(container=container, config=config.config)
-        return self._req_model("POST", models.Notification, json_data=config)
+        return self._req_model("POST", models.Notification, data=config)
 
     def create_teams(self, container: str, url: str) -> models.Notification:
         """ Create a Teams notification integration """
@@ -533,7 +546,7 @@ class Notifications(Endpoint):
         return self._req_model(
             "DELETE",
             models.Notification,
-            json_data={"notification_id": notification_id_expanded},
+            data=requests.NotificationGet(notification_id=notification_id_expanded),
         )
 
     def list(self) -> List[models.Notification]:
@@ -558,7 +571,7 @@ class Tasks(Endpoint):
         self.logger.debug("delete task: %s", task_id_expanded)
         return self._req_model(
-            "DELETE", models.Task, json_data={"task_id": task_id_expanded}
+            "DELETE", models.Task, data=requests.TaskGet(task_id=task_id_expanded)
         )
 
     def get(self, task_id: UUID_EXPANSION) -> models.Task:
@@ -570,9 +583,14 @@ class Tasks(Endpoint):
         self.logger.debug("get task: %s", task_id_expanded)
         return self._req_model(
-            "GET", models.Task, json_data={"task_id": task_id_expanded}
+            "GET", models.Task, data=requests.TaskGet(task_id=task_id_expanded)
         )
 
+    def create_with_config(self, config: models.TaskConfig) -> models.Task:
+        """ Create a Task using TaskConfig """
+        return self._req_model("POST", models.Task, data=config)
+
     def create(
         self,
         job_id: UUID_EXPANSION,
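The new create_with_config gives callers a direct way to submit a prebuilt TaskConfig; the existing create() wrapper is reworked below to funnel through it. A hypothetical usage sketch (the client instance, pool name, and task details are placeholders, not taken from this diff):

    # Hypothetical usage of the new entry point.
    config = models.TaskConfig(
        job_id=job.job_id,
        prereq_tasks=[],
        # required fields beyond type/duration vary by task type
        task=models.TaskDetails(type=enums.TaskType.libfuzzer_fuzz, duration=1),
        pool=models.TaskPool(count=1, pool_name="example-pool"),
        containers=[],
        tags={},
    )
    task = onefuzz.tasks.create_with_config(config)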
@@ -632,48 +650,47 @@ class Tasks(Endpoint):
         prereq_tasks = []
 
         containers_submit = []
-        all_containers = [x.name for x in self.onefuzz.containers.list()]
         for (container_type, container) in containers:
-            if container not in all_containers:
-                raise Exception("invalid container: %s" % container)
-            containers_submit.append({"type": container_type, "name": container})
+            containers_submit.append(
+                models.TaskContainers(name=container, type=container_type)
+            )
 
-        data = {
-            "job_id": job_id_expanded,
-            "prereq_tasks": prereq_tasks,
-            "task": {
-                "type": task_type,
-                "target_exe": target_exe,
-                "duration": duration,
-                "target_env": target_env,
-                "target_options": target_options,
-                "target_workers": target_workers,
-                "target_options_merge": target_options_merge,
-                "target_timeout": target_timeout,
-                "rename_output": rename_output,
-                "supervisor_exe": supervisor_exe,
-                "supervisor_options": supervisor_options,
-                "supervisor_env": supervisor_env,
-                "supervisor_input_marker": supervisor_input_marker,
-                "analyzer_exe": analyzer_exe,
-                "analyzer_options": analyzer_options,
-                "analyzer_env": analyzer_env,
-                "stats_file": stats_file,
-                "stats_format": stats_format,
-                "generator_exe": generator_exe,
-                "generator_options": generator_options,
-                "wait_for_files": task_wait_for_files,
-                "reboot_after_setup": reboot_after_setup,
-                "check_asan_log": check_asan_log,
-                "check_debugger": check_debugger,
-                "check_retry_count": check_retry_count,
-            },
-            "pool": {"count": vm_count, "pool_name": pool_name},
-            "containers": containers_submit,
-            "tags": tags,
-        }
-
-        return self._req_model("POST", models.Task, json_data=data)
+        config = models.TaskConfig(
+            job_id=job_id_expanded,
+            prereq_tasks=prereq_tasks,
+            task=models.TaskDetails(
+                type=task_type,
+                duration=duration,
+                target_exe=target_exe,
+                target_env=target_env,
+                target_options=target_options,
+                target_options_merge=target_options_merge,
+                target_timeout=target_timeout,
+                target_workers=target_workers,
+                rename_output=rename_output,
+                supervisor_exe=supervisor_exe,
+                supervisor_options=supervisor_options,
+                supervisor_env=supervisor_env,
+                supervisor_input_marker=supervisor_input_marker,
+                analyzer_exe=analyzer_exe,
+                analyzer_env=analyzer_env,
+                analyzer_options=analyzer_options,
+                stats_file=stats_file,
+                stats_format=stats_format,
+                generator_exe=generator_exe,
+                generator_options=generator_options,
+                wait_for_files=task_wait_for_files,
+                reboot_after_setup=reboot_after_setup,
+                check_asan_log=check_asan_log,
+                check_debugger=check_debugger,
+                check_retry_count=check_retry_count,
+            ),
+            pool=models.TaskPool(count=vm_count, pool_name=pool_name),
+            containers=containers_submit,
+            tags=tags,
+        )
+
+        return self.create_with_config(config)
 
     def list(
         self,
@@ -691,9 +708,11 @@ class Tasks(Endpoint):
             lambda: [str(x.job_id) for x in self.onefuzz.jobs.list()],
         )
 
-        data = {"state": state, "job_id": job_id_expanded}
-
-        return self._req_model_list("GET", models.Task, json_data=data)
+        return self._req_model_list(
+            "GET",
+            models.Task,
+            data=requests.TaskSearch(job_id=job_id_expanded, state=state),
+        )
 
 
 class JobContainers(Endpoint):
@@ -754,7 +773,7 @@ class Jobs(Endpoint):
         self.logger.debug("delete job: %s", job_id_expanded)
         return self._req_model(
-            "DELETE", models.Job, json_data={"job_id": job_id_expanded}
+            "DELETE", models.Job, data=requests.JobGet(job_id=job_id_expanded)
         )
 
     def get(self, job_id: UUID_EXPANSION) -> models.Job:
@@ -763,27 +782,31 @@ class Jobs(Endpoint):
             "job_id", job_id, lambda: [str(x.job_id) for x in self.list()]
         )
         self.logger.debug("get job: %s", job_id_expanded)
-        job = self._req_model("GET", models.Job, json_data={"job_id": job_id_expanded})
-        # TODO
-        # job["tasks"] = self.onefuzz.tasks.list(job_id=job["job_id"])
+        job = self._req_model(
+            "GET", models.Job, data=requests.JobGet(job_id=job_id_expanded)
+        )
         return job
 
+    def create_with_config(self, config: models.JobConfig) -> models.Job:
+        """ Create a job """
+        self.logger.debug(
+            "create job: project:%s name:%s build:%s",
+            config.project,
+            config.name,
+            config.build,
+        )
+        return self._req_model(
+            "POST",
+            models.Job,
+            data=config,
+        )
+
     def create(
         self, project: str, name: str, build: str, duration: int = 24
     ) -> models.Job:
         """ Create a job """
-        self.logger.debug(
-            "create job: project:%s name:%s build:%s", project, name, build
-        )
-        return self._req_model(
-            "POST",
-            models.Job,
-            json_data={
-                "project": project,
-                "name": name,
-                "build": build,
-                "duration": duration,
-            },
-        )
+        return self.create_with_config(
+            models.JobConfig(project=project, name=name, build=build, duration=duration)
+        )
 
     def list(
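Job creation follows the same pattern: create() now just wraps the new create_with_config. A hypothetical call with placeholder values:

    # Hypothetical usage; the project/name/build values are placeholders.
    job = onefuzz.jobs.create_with_config(
        models.JobConfig(project="example", name="lib", build="build-1", duration=24)
    )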
@@ -792,10 +815,9 @@ class Jobs(Endpoint):
     ) -> List[models.Job]:
         """ Get information about all jobs """
         self.logger.debug("list jobs")
-        data = {"state": job_state}
-        jobs = self._req_model_list("GET", models.Job, json_data=data)
-
-        return jobs
+        return self._req_model_list(
+            "GET", models.Job, data=requests.JobSearch(state=job_state)
+        )
 
 
 class Pool(Endpoint):
@@ -823,20 +845,15 @@ class Pool(Endpoint):
         return self._req_model(
             "POST",
             models.Pool,
-            json_data={
-                "name": name,
-                "os": os,
-                "arch": arch,
-                "managed": managed,
-                "client_id": client_id,
-                "autoscale": None,
-            },
+            data=requests.PoolCreate(
+                name=name, os=os, arch=arch, managed=managed, client_id=client_id
+            ),
         )
 
     def get_config(self, pool_name: str) -> models.AgentConfig:
         """ Get the agent configuration for the pool """
-        pool = self._req_model("GET", models.Pool, json_data={"name": pool_name})
+        pool = self.get(pool_name)
 
         if pool.config is None:
             raise Exception("Missing AgentConfig in response")
@@ -858,7 +875,7 @@ class Pool(Endpoint):
         return self._req_model(
             "DELETE",
             responses.BoolResult,
-            json_data={"name": expanded_name, "now": now},
+            data=requests.PoolStop(name=expanded_name, now=now),
         )
 
     def get(self, name: str) -> models.Pool:
@@ -867,13 +884,17 @@ class Pool(Endpoint):
             "name", name, lambda x: False, lambda: [x.name for x in self.list()]
         )
 
-        return self._req_model("GET", models.Pool, json_data={"name": expanded_name})
+        return self._req_model(
+            "GET", models.Pool, data=requests.PoolSearch(name=expanded_name)
+        )
 
     def list(
         self, *, state: Optional[List[enums.PoolState]] = None
     ) -> List[models.Pool]:
         self.logger.debug("list worker pools")
-        return self._req_model_list("GET", models.Pool, json_data={"state": state})
+        return self._req_model_list(
+            "GET", models.Pool, data=requests.PoolSearch(state=state)
+        )
 
 
 class Node(Endpoint):
@@ -890,7 +911,7 @@ class Node(Endpoint):
         )
 
         return self._req_model(
-            "GET", models.Node, json_data={"machine_id": machine_id_expanded}
+            "GET", models.Node, data=requests.NodeGet(machine_id=machine_id_expanded)
        )
 
     def halt(self, machine_id: UUID_EXPANSION) -> responses.BoolResult:
@@ -904,7 +925,7 @@ class Node(Endpoint):
         return self._req_model(
             "DELETE",
             responses.BoolResult,
-            json_data={"machine_id": machine_id_expanded},
+            data=requests.NodeGet(machine_id=machine_id_expanded),
         )
 
     def reimage(self, machine_id: UUID_EXPANSION) -> responses.BoolResult:
@@ -916,7 +937,9 @@ class Node(Endpoint):
         )
 
         return self._req_model(
-            "PATCH", responses.BoolResult, json_data={"machine_id": machine_id_expanded}
+            "PATCH",
+            responses.BoolResult,
+            data=requests.NodeGet(machine_id=machine_id_expanded),
         )
 
     def list(
@@ -947,11 +970,9 @@ class Node(Endpoint):
         return self._req_model_list(
             "GET",
             models.Node,
-            json_data={
-                "state": state,
-                "scaleset_id": scaleset_id_expanded,
-                "pool_name": pool_name,
-            },
+            data=requests.NodeSearch(
+                scaleset_id=scaleset_id_expanded, state=state, pool_name=pool_name
+            ),
         )
@@ -1006,15 +1027,15 @@ class Scaleset(Endpoint):
         return self._req_model(
             "POST",
             models.Scaleset,
-            json_data={
-                "pool_name": pool_name,
-                "vm_sku": vm_sku,
-                "image": image,
-                "region": region,
-                "size": size,
-                "spot_instances": spot_instances,
-                "tags": tags,
-            },
+            data=requests.ScalesetCreate(
+                pool_name=pool_name,
+                vm_sku=vm_sku,
+                image=image,
+                region=region,
+                size=size,
+                spot_instances=spot_instances,
+                tags=tags,
+            ),
         )
 
     def shutdown(
@@ -1030,7 +1051,7 @@ class Scaleset(Endpoint):
         return self._req_model(
             "DELETE",
             responses.BoolResult,
-            json_data={"scaleset_id": scaleset_id_expanded, "now": now},
+            data=requests.ScalesetStop(scaleset_id=scaleset_id_expanded, now=now),
         )
 
     def get(
@@ -1046,10 +1067,9 @@ class Scaleset(Endpoint):
         return self._req_model(
             "GET",
             models.Scaleset,
-            json_data={
-                "scaleset_id": scaleset_id_expanded,
-                "include_auth": include_auth,
-            },
+            data=requests.ScalesetSearch(
+                scaleset_id=scaleset_id_expanded, include_auth=include_auth
+            ),
         )
 
     def update(
@@ -1065,7 +1085,7 @@ class Scaleset(Endpoint):
         return self._req_model(
             "PATCH",
             models.Scaleset,
-            json_data={"scaleset_id": scaleset_id_expanded, "size": size},
+            data=requests.ScalesetUpdate(scaleset_id=scaleset_id_expanded, size=size),
         )
 
     def list(
@@ -1074,7 +1094,9 @@ class Scaleset(Endpoint):
         state: Optional[List[enums.ScalesetState]] = None,
     ) -> List[models.Scaleset]:
         self.logger.debug("list scalesets")
-        return self._req_model_list("GET", models.Scaleset, json_data={"state": state})
+        return self._req_model_list(
+            "GET", models.Scaleset, data=requests.ScalesetSearch(state=state)
+        )
 
 
 class ScalesetProxy(Endpoint):
@@ -1105,18 +1127,18 @@ class ScalesetProxy(Endpoint):
         return self._req_model(
             "DELETE",
             responses.BoolResult,
-            json_data={
-                "scaleset_id": scaleset.scaleset_id,
-                "machine_id": machine_id_expanded,
-                "dst_port": dst_port,
-            },
+            data=requests.ProxyDelete(
+                scaleset_id=scaleset.scaleset_id,
+                machine_id=machine_id_expanded,
+                dst_port=dst_port,
+            ),
         )
 
     def reset(self, region: str) -> responses.BoolResult:
         """ Reset the proxy for an existing region """
         return self._req_model(
-            "PATCH", responses.BoolResult, json_data={"region": region}
+            "PATCH", responses.BoolResult, data=requests.ProxyReset(region=region)
         )
 
     def get(
@@ -1134,11 +1156,11 @@ class ScalesetProxy(Endpoint):
         proxy = self._req_model(
             "GET",
             responses.ProxyGetResult,
-            json_data={
-                "scaleset_id": scaleset.scaleset_id,
-                "machine_id": machine_id_expanded,
-                "dst_port": dst_port,
-            },
+            data=requests.ProxyGet(
+                scaleset_id=scaleset.scaleset_id,
+                machine_id=machine_id_expanded,
+                dst_port=dst_port,
+            ),
         )
 
         return proxy
@@ -1165,12 +1187,12 @@ class ScalesetProxy(Endpoint):
         return self._req_model(
             "POST",
             responses.ProxyGetResult,
-            json_data={
-                "scaleset_id": scaleset.scaleset_id,
-                "machine_id": machine_id_expanded,
-                "dst_port": dst_port,
-                "duration": duration,
-            },
+            data=requests.ProxyCreate(
+                scaleset_id=scaleset.scaleset_id,
+                machine_id=machine_id_expanded,
+                dst_port=dst_port,
+                duration=duration,
+            ),
         )