move SDK to use request models rather than hand-crafted json (#191)

Author: bmc-msft
Committed by: GitHub
Date: 2020-10-23 08:39:45 -04:00
Parent: d769072343
Commit: 1c06d7085a


@@ -10,7 +10,7 @@ import re
import subprocess # nosec
import uuid
from shutil import which
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar
+from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar
from uuid import UUID
import pkg_resources
@@ -66,12 +66,16 @@ class Endpoint:
method: str,
model: Type[A],
*,
-json_data: Any = None,
-params: Optional[Dict[str, Any]] = None,
+data: Optional[BaseModel] = None,
+as_params: bool = False,
) -> A:
+if as_params:
+response = self.onefuzz._backend.request(method, self.endpoint, params=data)
+else:
response = self.onefuzz._backend.request(
-method, self.endpoint, json_data=json_data, params=params
+method, self.endpoint, json_data=data
)
return model.parse_obj(response)
def _req_model_list(
@@ -79,11 +83,14 @@ class Endpoint:
method: str,
model: Type[A],
*,
-json_data: Any = None,
-params: Optional[Dict[str, Any]] = None,
+data: Optional[BaseModel] = None,
+as_params: bool = False,
) -> List[A]:
+if as_params:
+response = self.onefuzz._backend.request(method, self.endpoint, params=data)
+else:
response = self.onefuzz._backend.request(
-method, self.endpoint, json_data=json_data, params=params
+method, self.endpoint, json_data=data
)
return [model.parse_obj(x) for x in response]
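
Note: the heart of this change is in these two helpers. Callers now hand _req_model / _req_model_list a typed pydantic model via data=, and as_params=True routes the same model into query parameters instead of a JSON body. A minimal sketch of that idea, assuming pydantic v1 as this SDK uses; NodeGetExample is illustrative, not the real onefuzztypes.requests model:

    from typing import Optional
    from uuid import UUID
    from pydantic import BaseModel

    class NodeGetExample(BaseModel):
        machine_id: Optional[UUID] = None

    req = NodeGetExample(machine_id=UUID(int=1))
    body = req.json()                     # serialized JSON body for json_data=
    params = req.dict(exclude_none=True)  # dict form, usable as query params
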
@@ -240,7 +247,9 @@ class Containers(Endpoint):
def get(self, name: str) -> responses.ContainerInfo:
""" Get a fully qualified SAS URL for a container """
self.logger.debug("get container: %s", name)
-return self._req_model("GET", responses.ContainerInfo, json_data={"name": name})
+return self._req_model(
+"GET", responses.ContainerInfo, data=requests.ContainerGet(name=name)
+)
def create(
self, name: str, metadata: Optional[Dict[str, str]] = None
@@ -250,13 +259,15 @@ class Containers(Endpoint):
return self._req_model(
"POST",
responses.ContainerInfo,
-json_data={"name": name, "metadata": metadata},
+data=requests.ContainerCreate(name=name, metadata=metadata),
)
def delete(self, name: str) -> responses.BoolResult:
""" Delete a storage container """
self.logger.debug("delete container: %s", name)
-return self._req_model("DELETE", responses.BoolResult, json_data={"name": name})
+return self._req_model(
+"DELETE", responses.BoolResult, data=requests.ContainerDelete(name=name)
+)
def list(self) -> List[responses.ContainerInfoBase]:
""" Get a list of containers """
@@ -276,7 +287,9 @@ class Repro(Endpoint):
)
self.logger.debug("get repro vm: %s", vm_id_expanded)
-return self._req_model("GET", models.Repro, json_data={"vm_id": vm_id_expanded})
+return self._req_model(
+"GET", models.Repro, data=requests.ReproGet(vm_id=vm_id_expanded)
+)
def create(self, container: str, path: str, duration: int = 24) -> models.Repro:
""" Create a Reproduction VM from a Crash Report """
@@ -286,7 +299,7 @@ class Repro(Endpoint):
return self._req_model(
"POST",
models.Repro,
-json_data={"container": container, "path": path, "duration": duration},
+data=models.ReproConfig(container=container, path=path, duration=duration),
)
def delete(self, vm_id: UUID_EXPANSION) -> models.Repro:
@@ -297,13 +310,13 @@ class Repro(Endpoint):
self.logger.debug("deleting repro vm: %s", vm_id_expanded)
return self._req_model(
-"DELETE", models.Repro, json_data={"vm_id": vm_id_expanded}
+"DELETE", models.Repro, data=requests.ReproGet(vm_id=vm_id_expanded)
)
def list(self) -> List[models.Repro]:
""" List all VMs """
self.logger.debug("listing repro vms")
-return self._req_model_list("GET", models.Repro, json_data={})
+return self._req_model_list("GET", models.Repro, data=requests.ReproGet())
def _dbg_linux(
self, repro: models.Repro, debug_command: Optional[str]
@@ -469,7 +482,7 @@ class Notifications(Endpoint):
""" Create a notification based on a config file """
config = requests.NotificationCreate(container=container, config=config.config)
-return self._req_model("POST", models.Notification, json_data=config)
+return self._req_model("POST", models.Notification, data=config)
def create_teams(self, container: str, url: str) -> models.Notification:
""" Create a Teams notification integration """
@@ -533,7 +546,7 @@ class Notifications(Endpoint):
return self._req_model(
"DELETE",
models.Notification,
-json_data={"notification_id": notification_id_expanded},
+data=requests.NotificationGet(notification_id=notification_id_expanded),
)
def list(self) -> List[models.Notification]:
@@ -558,7 +571,7 @@ class Tasks(Endpoint):
self.logger.debug("delete task: %s", task_id_expanded)
return self._req_model(
-"DELETE", models.Task, json_data={"task_id": task_id_expanded}
+"DELETE", models.Task, data=requests.TaskGet(task_id=task_id_expanded)
)
def get(self, task_id: UUID_EXPANSION) -> models.Task:
@@ -570,9 +583,14 @@ class Tasks(Endpoint):
self.logger.debug("get task: %s", task_id_expanded)
return self._req_model(
-"GET", models.Task, json_data={"task_id": task_id_expanded}
+"GET", models.Task, data=requests.TaskGet(task_id=task_id_expanded)
)
+def create_with_config(self, config: models.TaskConfig) -> models.Task:
+""" Create a Task using TaskConfig """
+return self._req_model("POST", models.Task, data=config)
def create(
self,
job_id: UUID_EXPANSION,
@@ -632,48 +650,47 @@ class Tasks(Endpoint):
prereq_tasks = []
containers_submit = []
all_containers = [x.name for x in self.onefuzz.containers.list()]
for (container_type, container) in containers:
if container not in all_containers:
raise Exception("invalid container: %s" % container)
-containers_submit.append({"type": container_type, "name": container})
+containers_submit.append(
+models.TaskContainers(name=container, type=container_type)
+)
-data = {
-"job_id": job_id_expanded,
-"prereq_tasks": prereq_tasks,
-"task": {
-"type": task_type,
-"target_exe": target_exe,
-"duration": duration,
-"target_env": target_env,
-"target_options": target_options,
-"target_workers": target_workers,
-"target_options_merge": target_options_merge,
-"target_timeout": target_timeout,
-"rename_output": rename_output,
-"supervisor_exe": supervisor_exe,
-"supervisor_options": supervisor_options,
-"supervisor_env": supervisor_env,
-"supervisor_input_marker": supervisor_input_marker,
-"analyzer_exe": analyzer_exe,
-"analyzer_options": analyzer_options,
-"analyzer_env": analyzer_env,
-"stats_file": stats_file,
-"stats_format": stats_format,
-"generator_exe": generator_exe,
-"generator_options": generator_options,
-"wait_for_files": task_wait_for_files,
-"reboot_after_setup": reboot_after_setup,
-"check_asan_log": check_asan_log,
-"check_debugger": check_debugger,
-"check_retry_count": check_retry_count,
-},
-"pool": {"count": vm_count, "pool_name": pool_name},
-"containers": containers_submit,
-"tags": tags,
-}
+config = models.TaskConfig(
+job_id=job_id_expanded,
+prereq_tasks=prereq_tasks,
+task=models.TaskDetails(
+type=task_type,
+duration=duration,
+target_exe=target_exe,
+target_env=target_env,
+target_options=target_options,
+target_options_merge=target_options_merge,
+target_timeout=target_timeout,
+target_workers=target_workers,
+rename_output=rename_output,
+supervisor_exe=supervisor_exe,
+supervisor_options=supervisor_options,
+supervisor_env=supervisor_env,
+supervisor_input_marker=supervisor_input_marker,
+analyzer_exe=analyzer_exe,
+analyzer_env=analyzer_env,
+analyzer_options=analyzer_options,
+stats_file=stats_file,
+stats_format=stats_format,
+generator_exe=generator_exe,
+generator_options=generator_options,
+wait_for_files=task_wait_for_files,
+reboot_after_setup=reboot_after_setup,
+check_asan_log=check_asan_log,
+check_debugger=check_debugger,
+check_retry_count=check_retry_count,
+),
+pool=models.TaskPool(count=vm_count, pool_name=pool_name),
+containers=containers_submit,
+tags=tags,
+)
-return self._req_model("POST", models.Task, json_data=data)
+return self.create_with_config(config)
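
Note: with the hand-built dict replaced by models.TaskConfig, create() becomes a thin wrapper over the new create_with_config(), and a full task definition can round-trip through JSON. A hedged usage sketch, given an Onefuzz client instance (the file name is a placeholder; parse_file is a pydantic v1 helper):

    config = models.TaskConfig.parse_file("task_config.json")
    task = onefuzz.tasks.create_with_config(config)
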
def list(
self,
@@ -691,9 +708,11 @@ class Tasks(Endpoint):
lambda: [str(x.job_id) for x in self.onefuzz.jobs.list()],
)
-data = {"state": state, "job_id": job_id_expanded}
-return self._req_model_list("GET", models.Task, json_data=data)
+return self._req_model_list(
+"GET",
+models.Task,
+data=requests.TaskSearch(job_id=job_id_expanded, state=state),
+)
class JobContainers(Endpoint):
@@ -754,7 +773,7 @@ class Jobs(Endpoint):
self.logger.debug("delete job: %s", job_id_expanded)
return self._req_model(
-"DELETE", models.Job, json_data={"job_id": job_id_expanded}
+"DELETE", models.Job, data=requests.JobGet(job_id=job_id_expanded)
)
def get(self, job_id: UUID_EXPANSION) -> models.Job:
@@ -763,27 +782,31 @@ class Jobs(Endpoint):
"job_id", job_id, lambda: [str(x.job_id) for x in self.list()]
)
self.logger.debug("get job: %s", job_id_expanded)
-job = self._req_model("GET", models.Job, json_data={"job_id": job_id_expanded})
-# TODO
-# job["tasks"] = self.onefuzz.tasks.list(job_id=job["job_id"])
+job = self._req_model(
+"GET", models.Job, data=requests.JobGet(job_id=job_id_expanded)
+)
return job
+def create_with_config(self, config: models.JobConfig) -> models.Job:
+""" Create a job """
+self.logger.debug(
+"create job: project:%s name:%s build:%s",
+config.project,
+config.name,
+config.build,
+)
+return self._req_model(
+"POST",
+models.Job,
+data=config,
+)
def create(
self, project: str, name: str, build: str, duration: int = 24
) -> models.Job:
""" Create a job """
-self.logger.debug(
-"create job: project:%s name:%s build:%s", project, name, build
-)
-return self._req_model(
-"POST",
-models.Job,
-json_data={
-"project": project,
-"name": name,
-"build": build,
-"duration": duration,
-},
+return self.create_with_config(
+models.JobConfig(project=project, name=name, build=build, duration=duration)
)
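
Note: same pattern as Tasks: create() now delegates, so the two calls below (sketched with placeholder values) produce identical request bodies:

    job = onefuzz.jobs.create("proj", "target-lib", "build-1")
    job = onefuzz.jobs.create_with_config(
        models.JobConfig(project="proj", name="target-lib", build="build-1", duration=24)
    )
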
def list(
@@ -792,10 +815,9 @@ class Jobs(Endpoint):
) -> List[models.Job]:
""" Get information about all jobs """
self.logger.debug("list jobs")
-data = {"state": job_state}
-jobs = self._req_model_list("GET", models.Job, json_data=data)
-return jobs
+return self._req_model_list(
+"GET", models.Job, data=requests.JobSearch(state=job_state)
+)
class Pool(Endpoint):
@@ -823,20 +845,15 @@ class Pool(Endpoint):
return self._req_model(
"POST",
models.Pool,
-json_data={
-"name": name,
-"os": os,
-"arch": arch,
-"managed": managed,
-"client_id": client_id,
-"autoscale": None,
-},
+data=requests.PoolCreate(
+name=name, os=os, arch=arch, managed=managed, client_id=client_id
+),
)
def get_config(self, pool_name: str) -> models.AgentConfig:
""" Get the agent configuration for the pool """
-pool = self._req_model("GET", models.Pool, json_data={"name": pool_name})
+pool = self.get(pool_name)
if pool.config is None:
raise Exception("Missing AgentConfig in response")
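
Note: the old dict had to spell out "autoscale": None on every call; a request model can default that field instead. A sketch of the behavior, assuming an Optional field with a None default (PoolCreateExample is illustrative, not the real requests.PoolCreate):

    from typing import Optional
    from pydantic import BaseModel

    class PoolCreateExample(BaseModel):
        name: str
        autoscale: Optional[dict] = None  # defaulted; callers may omit it

    PoolCreateExample(name="my-pool")  # autoscale is None without being passed
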
@@ -858,7 +875,7 @@ class Pool(Endpoint):
return self._req_model(
"DELETE",
responses.BoolResult,
-json_data={"name": expanded_name, "now": now},
+data=requests.PoolStop(name=expanded_name, now=now),
)
def get(self, name: str) -> models.Pool:
@@ -867,13 +884,17 @@ class Pool(Endpoint):
"name", name, lambda x: False, lambda: [x.name for x in self.list()]
)
-return self._req_model("GET", models.Pool, json_data={"name": expanded_name})
+return self._req_model(
+"GET", models.Pool, data=requests.PoolSearch(name=expanded_name)
+)
def list(
self, *, state: Optional[List[enums.PoolState]] = None
) -> List[models.Pool]:
self.logger.debug("list worker pools")
-return self._req_model_list("GET", models.Pool, json_data={"state": state})
+return self._req_model_list(
+"GET", models.Pool, data=requests.PoolSearch(state=state)
+)
class Node(Endpoint):
@@ -890,7 +911,7 @@ class Node(Endpoint):
)
return self._req_model(
-"GET", models.Node, json_data={"machine_id": machine_id_expanded}
+"GET", models.Node, data=requests.NodeGet(machine_id=machine_id_expanded)
)
def halt(self, machine_id: UUID_EXPANSION) -> responses.BoolResult:
@@ -904,7 +925,7 @@ class Node(Endpoint):
return self._req_model(
"DELETE",
responses.BoolResult,
-json_data={"machine_id": machine_id_expanded},
+data=requests.NodeGet(machine_id=machine_id_expanded),
)
def reimage(self, machine_id: UUID_EXPANSION) -> responses.BoolResult:
@@ -916,7 +937,9 @@ class Node(Endpoint):
)
return self._req_model(
-"PATCH", responses.BoolResult, json_data={"machine_id": machine_id_expanded}
+"PATCH",
+responses.BoolResult,
+data=requests.NodeGet(machine_id=machine_id_expanded),
)
def list(
@@ -947,11 +970,9 @@ class Node(Endpoint):
return self._req_model_list(
"GET",
models.Node,
-json_data={
-"state": state,
-"scaleset_id": scaleset_id_expanded,
-"pool_name": pool_name,
-},
+data=requests.NodeSearch(
+scaleset_id=scaleset_id_expanded, state=state, pool_name=pool_name
+),
)
@@ -1006,15 +1027,15 @@ class Scaleset(Endpoint):
return self._req_model(
"POST",
models.Scaleset,
-json_data={
-"pool_name": pool_name,
-"vm_sku": vm_sku,
-"image": image,
-"region": region,
-"size": size,
-"spot_instances": spot_instances,
-"tags": tags,
-},
+data=requests.ScalesetCreate(
+pool_name=pool_name,
+vm_sku=vm_sku,
+image=image,
+region=region,
+size=size,
+spot_instances=spot_instances,
+tags=tags,
+),
)
def shutdown(
@@ -1030,7 +1051,7 @@ class Scaleset(Endpoint):
return self._req_model(
"DELETE",
responses.BoolResult,
-json_data={"scaleset_id": scaleset_id_expanded, "now": now},
+data=requests.ScalesetStop(scaleset_id=scaleset_id_expanded, now=now),
)
def get(
@@ -1046,10 +1067,9 @@ class Scaleset(Endpoint):
return self._req_model(
"GET",
models.Scaleset,
-json_data={
-"scaleset_id": scaleset_id_expanded,
-"include_auth": include_auth,
-},
+data=requests.ScalesetSearch(
+scaleset_id=scaleset_id_expanded, include_auth=include_auth
+),
)
def update(
@@ -1065,7 +1085,7 @@ class Scaleset(Endpoint):
return self._req_model(
"PATCH",
models.Scaleset,
-json_data={"scaleset_id": scaleset_id_expanded, "size": size},
+data=requests.ScalesetUpdate(scaleset_id=scaleset_id_expanded, size=size),
)
def list(
@@ -1074,7 +1094,9 @@ class Scaleset(Endpoint):
state: Optional[List[enums.ScalesetState]] = None,
) -> List[models.Scaleset]:
self.logger.debug("list scalesets")
-return self._req_model_list("GET", models.Scaleset, json_data={"state": state})
+return self._req_model_list(
+"GET", models.Scaleset, data=requests.ScalesetSearch(state=state)
+)
class ScalesetProxy(Endpoint):
@@ -1105,18 +1127,18 @@ class ScalesetProxy(Endpoint):
return self._req_model(
"DELETE",
responses.BoolResult,
-json_data={
-"scaleset_id": scaleset.scaleset_id,
-"machine_id": machine_id_expanded,
-"dst_port": dst_port,
-},
+data=requests.ProxyDelete(
+scaleset_id=scaleset.scaleset_id,
+machine_id=machine_id_expanded,
+dst_port=dst_port,
+),
)
def reset(self, region: str) -> responses.BoolResult:
""" Reset the proxy for an existing region """
return self._req_model(
-"PATCH", responses.BoolResult, json_data={"region": region}
+"PATCH", responses.BoolResult, data=requests.ProxyReset(region=region)
)
def get(
@@ -1134,11 +1156,11 @@ class ScalesetProxy(Endpoint):
proxy = self._req_model(
"GET",
responses.ProxyGetResult,
-json_data={
-"scaleset_id": scaleset.scaleset_id,
-"machine_id": machine_id_expanded,
-"dst_port": dst_port,
-},
+data=requests.ProxyGet(
+scaleset_id=scaleset.scaleset_id,
+machine_id=machine_id_expanded,
+dst_port=dst_port,
+),
)
return proxy
@@ -1165,12 +1187,12 @@ class ScalesetProxy(Endpoint):
return self._req_model(
"POST",
responses.ProxyGetResult,
-json_data={
-"scaleset_id": scaleset.scaleset_id,
-"machine_id": machine_id_expanded,
-"dst_port": dst_port,
-"duration": duration,
-},
+data=requests.ProxyCreate(
+scaleset_id=scaleset.scaleset_id,
+machine_id=machine_id_expanded,
+dst_port=dst_port,
+duration=duration,
+),
)