mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-18 04:38:09 +00:00
move SDK to use request models rather than hand-crafted json (#191)
This commit is contained in:
@ -10,7 +10,7 @@ import re
|
||||
import subprocess # nosec
|
||||
import uuid
|
||||
from shutil import which
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar
|
||||
from uuid import UUID
|
||||
|
||||
import pkg_resources
|
||||
@ -66,12 +66,16 @@ class Endpoint:
|
||||
method: str,
|
||||
model: Type[A],
|
||||
*,
|
||||
json_data: Any = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
data: Optional[BaseModel] = None,
|
||||
as_params: bool = False,
|
||||
) -> A:
|
||||
response = self.onefuzz._backend.request(
|
||||
method, self.endpoint, json_data=json_data, params=params
|
||||
)
|
||||
if as_params:
|
||||
response = self.onefuzz._backend.request(method, self.endpoint, params=data)
|
||||
else:
|
||||
response = self.onefuzz._backend.request(
|
||||
method, self.endpoint, json_data=data
|
||||
)
|
||||
|
||||
return model.parse_obj(response)
|
||||
|
||||
def _req_model_list(
|
||||
@ -79,12 +83,15 @@ class Endpoint:
|
||||
method: str,
|
||||
model: Type[A],
|
||||
*,
|
||||
json_data: Any = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
data: Optional[BaseModel] = None,
|
||||
as_params: bool = False,
|
||||
) -> List[A]:
|
||||
response = self.onefuzz._backend.request(
|
||||
method, self.endpoint, json_data=json_data, params=params
|
||||
)
|
||||
if as_params:
|
||||
response = self.onefuzz._backend.request(method, self.endpoint, params=data)
|
||||
else:
|
||||
response = self.onefuzz._backend.request(
|
||||
method, self.endpoint, json_data=data
|
||||
)
|
||||
return [model.parse_obj(x) for x in response]
|
||||
|
||||
def _disambiguate(
|
||||
@ -240,7 +247,9 @@ class Containers(Endpoint):
|
||||
def get(self, name: str) -> responses.ContainerInfo:
|
||||
""" Get a fully qualified SAS URL for a container """
|
||||
self.logger.debug("get container: %s", name)
|
||||
return self._req_model("GET", responses.ContainerInfo, json_data={"name": name})
|
||||
return self._req_model(
|
||||
"GET", responses.ContainerInfo, data=requests.ContainerGet(name=name)
|
||||
)
|
||||
|
||||
def create(
|
||||
self, name: str, metadata: Optional[Dict[str, str]] = None
|
||||
@ -250,13 +259,15 @@ class Containers(Endpoint):
|
||||
return self._req_model(
|
||||
"POST",
|
||||
responses.ContainerInfo,
|
||||
json_data={"name": name, "metadata": metadata},
|
||||
data=requests.ContainerCreate(name=name, metadata=metadata),
|
||||
)
|
||||
|
||||
def delete(self, name: str) -> responses.BoolResult:
|
||||
""" Delete a storage container """
|
||||
self.logger.debug("delete container: %s", name)
|
||||
return self._req_model("DELETE", responses.BoolResult, json_data={"name": name})
|
||||
return self._req_model(
|
||||
"DELETE", responses.BoolResult, data=requests.ContainerDelete(name=name)
|
||||
)
|
||||
|
||||
def list(self) -> List[responses.ContainerInfoBase]:
|
||||
""" Get a list of containers """
|
||||
@ -276,7 +287,9 @@ class Repro(Endpoint):
|
||||
)
|
||||
|
||||
self.logger.debug("get repro vm: %s", vm_id_expanded)
|
||||
return self._req_model("GET", models.Repro, json_data={"vm_id": vm_id_expanded})
|
||||
return self._req_model(
|
||||
"GET", models.Repro, data=requests.ReproGet(vm_id=vm_id_expanded)
|
||||
)
|
||||
|
||||
def create(self, container: str, path: str, duration: int = 24) -> models.Repro:
|
||||
""" Create a Reproduction VM from a Crash Report """
|
||||
@ -286,7 +299,7 @@ class Repro(Endpoint):
|
||||
return self._req_model(
|
||||
"POST",
|
||||
models.Repro,
|
||||
json_data={"container": container, "path": path, "duration": duration},
|
||||
data=models.ReproConfig(container=container, path=path, duration=duration),
|
||||
)
|
||||
|
||||
def delete(self, vm_id: UUID_EXPANSION) -> models.Repro:
|
||||
@ -297,13 +310,13 @@ class Repro(Endpoint):
|
||||
|
||||
self.logger.debug("deleting repro vm: %s", vm_id_expanded)
|
||||
return self._req_model(
|
||||
"DELETE", models.Repro, json_data={"vm_id": vm_id_expanded}
|
||||
"DELETE", models.Repro, data=requests.ReproGet(vm_id=vm_id_expanded)
|
||||
)
|
||||
|
||||
def list(self) -> List[models.Repro]:
|
||||
""" List all VMs """
|
||||
self.logger.debug("listing repro vms")
|
||||
return self._req_model_list("GET", models.Repro, json_data={})
|
||||
return self._req_model_list("GET", models.Repro, data=requests.ReproGet())
|
||||
|
||||
def _dbg_linux(
|
||||
self, repro: models.Repro, debug_command: Optional[str]
|
||||
@ -469,7 +482,7 @@ class Notifications(Endpoint):
|
||||
""" Create a notification based on a config file """
|
||||
|
||||
config = requests.NotificationCreate(container=container, config=config.config)
|
||||
return self._req_model("POST", models.Notification, json_data=config)
|
||||
return self._req_model("POST", models.Notification, data=config)
|
||||
|
||||
def create_teams(self, container: str, url: str) -> models.Notification:
|
||||
""" Create a Teams notification integration """
|
||||
@ -533,7 +546,7 @@ class Notifications(Endpoint):
|
||||
return self._req_model(
|
||||
"DELETE",
|
||||
models.Notification,
|
||||
json_data={"notification_id": notification_id_expanded},
|
||||
data=requests.NotificationGet(notification_id=notification_id_expanded),
|
||||
)
|
||||
|
||||
def list(self) -> List[models.Notification]:
|
||||
@ -558,7 +571,7 @@ class Tasks(Endpoint):
|
||||
self.logger.debug("delete task: %s", task_id_expanded)
|
||||
|
||||
return self._req_model(
|
||||
"DELETE", models.Task, json_data={"task_id": task_id_expanded}
|
||||
"DELETE", models.Task, data=requests.TaskGet(task_id=task_id_expanded)
|
||||
)
|
||||
|
||||
def get(self, task_id: UUID_EXPANSION) -> models.Task:
|
||||
@ -570,9 +583,14 @@ class Tasks(Endpoint):
|
||||
self.logger.debug("get task: %s", task_id_expanded)
|
||||
|
||||
return self._req_model(
|
||||
"GET", models.Task, json_data={"task_id": task_id_expanded}
|
||||
"GET", models.Task, data=requests.TaskGet(task_id=task_id_expanded)
|
||||
)
|
||||
|
||||
def create_with_config(self, config: models.TaskConfig) -> models.Task:
|
||||
""" Create a Task using TaskConfig """
|
||||
|
||||
return self._req_model("POST", models.Task, data=config)
|
||||
|
||||
def create(
|
||||
self,
|
||||
job_id: UUID_EXPANSION,
|
||||
@ -632,48 +650,47 @@ class Tasks(Endpoint):
|
||||
prereq_tasks = []
|
||||
|
||||
containers_submit = []
|
||||
all_containers = [x.name for x in self.onefuzz.containers.list()]
|
||||
for (container_type, container) in containers:
|
||||
if container not in all_containers:
|
||||
raise Exception("invalid container: %s" % container)
|
||||
containers_submit.append({"type": container_type, "name": container})
|
||||
containers_submit.append(
|
||||
models.TaskContainers(name=container, type=container_type)
|
||||
)
|
||||
|
||||
data = {
|
||||
"job_id": job_id_expanded,
|
||||
"prereq_tasks": prereq_tasks,
|
||||
"task": {
|
||||
"type": task_type,
|
||||
"target_exe": target_exe,
|
||||
"duration": duration,
|
||||
"target_env": target_env,
|
||||
"target_options": target_options,
|
||||
"target_workers": target_workers,
|
||||
"target_options_merge": target_options_merge,
|
||||
"target_timeout": target_timeout,
|
||||
"rename_output": rename_output,
|
||||
"supervisor_exe": supervisor_exe,
|
||||
"supervisor_options": supervisor_options,
|
||||
"supervisor_env": supervisor_env,
|
||||
"supervisor_input_marker": supervisor_input_marker,
|
||||
"analyzer_exe": analyzer_exe,
|
||||
"analyzer_options": analyzer_options,
|
||||
"analyzer_env": analyzer_env,
|
||||
"stats_file": stats_file,
|
||||
"stats_format": stats_format,
|
||||
"generator_exe": generator_exe,
|
||||
"generator_options": generator_options,
|
||||
"wait_for_files": task_wait_for_files,
|
||||
"reboot_after_setup": reboot_after_setup,
|
||||
"check_asan_log": check_asan_log,
|
||||
"check_debugger": check_debugger,
|
||||
"check_retry_count": check_retry_count,
|
||||
},
|
||||
"pool": {"count": vm_count, "pool_name": pool_name},
|
||||
"containers": containers_submit,
|
||||
"tags": tags,
|
||||
}
|
||||
config = models.TaskConfig(
|
||||
job_id=job_id_expanded,
|
||||
prereq_tasks=prereq_tasks,
|
||||
task=models.TaskDetails(
|
||||
type=task_type,
|
||||
duration=duration,
|
||||
target_exe=target_exe,
|
||||
target_env=target_env,
|
||||
target_options=target_options,
|
||||
target_options_merge=target_options_merge,
|
||||
target_timeout=target_timeout,
|
||||
target_workers=target_workers,
|
||||
rename_output=rename_output,
|
||||
supervisor_exe=supervisor_exe,
|
||||
supervisor_options=supervisor_options,
|
||||
supervisor_env=supervisor_env,
|
||||
supervisor_input_marker=supervisor_input_marker,
|
||||
analyzer_exe=analyzer_exe,
|
||||
analyzer_env=analyzer_env,
|
||||
analyzer_options=analyzer_options,
|
||||
stats_file=stats_file,
|
||||
stats_format=stats_format,
|
||||
generator_exe=generator_exe,
|
||||
generator_options=generator_options,
|
||||
wait_for_files=task_wait_for_files,
|
||||
reboot_after_setup=reboot_after_setup,
|
||||
check_asan_log=check_asan_log,
|
||||
check_debugger=check_debugger,
|
||||
check_retry_count=check_retry_count,
|
||||
),
|
||||
pool=models.TaskPool(count=vm_count, pool_name=pool_name),
|
||||
containers=containers_submit,
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
return self._req_model("POST", models.Task, json_data=data)
|
||||
return self.create_with_config(config)
|
||||
|
||||
def list(
|
||||
self,
|
||||
@ -691,9 +708,11 @@ class Tasks(Endpoint):
|
||||
lambda: [str(x.job_id) for x in self.onefuzz.jobs.list()],
|
||||
)
|
||||
|
||||
data = {"state": state, "job_id": job_id_expanded}
|
||||
|
||||
return self._req_model_list("GET", models.Task, json_data=data)
|
||||
return self._req_model_list(
|
||||
"GET",
|
||||
models.Task,
|
||||
data=requests.TaskSearch(job_id=job_id_expanded, state=state),
|
||||
)
|
||||
|
||||
|
||||
class JobContainers(Endpoint):
|
||||
@ -754,7 +773,7 @@ class Jobs(Endpoint):
|
||||
|
||||
self.logger.debug("delete job: %s", job_id_expanded)
|
||||
return self._req_model(
|
||||
"DELETE", models.Job, json_data={"job_id": job_id_expanded}
|
||||
"DELETE", models.Job, data=requests.JobGet(job_id=job_id_expanded)
|
||||
)
|
||||
|
||||
def get(self, job_id: UUID_EXPANSION) -> models.Job:
|
||||
@ -763,27 +782,31 @@ class Jobs(Endpoint):
|
||||
"job_id", job_id, lambda: [str(x.job_id) for x in self.list()]
|
||||
)
|
||||
self.logger.debug("get job: %s", job_id_expanded)
|
||||
job = self._req_model("GET", models.Job, json_data={"job_id": job_id_expanded})
|
||||
# TODO
|
||||
# job["tasks"] = self.onefuzz.tasks.list(job_id=job["job_id"])
|
||||
job = self._req_model(
|
||||
"GET", models.Job, data=requests.JobGet(job_id=job_id_expanded)
|
||||
)
|
||||
return job
|
||||
|
||||
def create_with_config(self, config: models.JobConfig) -> models.Job:
|
||||
""" Create a job """
|
||||
self.logger.debug(
|
||||
"create job: project:%s name:%s build:%s",
|
||||
config.project,
|
||||
config.name,
|
||||
config.build,
|
||||
)
|
||||
return self._req_model(
|
||||
"POST",
|
||||
models.Job,
|
||||
data=config,
|
||||
)
|
||||
|
||||
def create(
|
||||
self, project: str, name: str, build: str, duration: int = 24
|
||||
) -> models.Job:
|
||||
""" Create a job """
|
||||
self.logger.debug(
|
||||
"create job: project:%s name:%s build:%s", project, name, build
|
||||
)
|
||||
return self._req_model(
|
||||
"POST",
|
||||
models.Job,
|
||||
json_data={
|
||||
"project": project,
|
||||
"name": name,
|
||||
"build": build,
|
||||
"duration": duration,
|
||||
},
|
||||
return self.create_with_config(
|
||||
models.JobConfig(project=project, name=name, build=build, duration=duration)
|
||||
)
|
||||
|
||||
def list(
|
||||
@ -792,10 +815,9 @@ class Jobs(Endpoint):
|
||||
) -> List[models.Job]:
|
||||
""" Get information about all jobs """
|
||||
self.logger.debug("list jobs")
|
||||
data = {"state": job_state}
|
||||
jobs = self._req_model_list("GET", models.Job, json_data=data)
|
||||
|
||||
return jobs
|
||||
return self._req_model_list(
|
||||
"GET", models.Job, data=requests.JobSearch(state=job_state)
|
||||
)
|
||||
|
||||
|
||||
class Pool(Endpoint):
|
||||
@ -823,20 +845,15 @@ class Pool(Endpoint):
|
||||
return self._req_model(
|
||||
"POST",
|
||||
models.Pool,
|
||||
json_data={
|
||||
"name": name,
|
||||
"os": os,
|
||||
"arch": arch,
|
||||
"managed": managed,
|
||||
"client_id": client_id,
|
||||
"autoscale": None,
|
||||
},
|
||||
data=requests.PoolCreate(
|
||||
name=name, os=os, arch=arch, managed=managed, client_id=client_id
|
||||
),
|
||||
)
|
||||
|
||||
def get_config(self, pool_name: str) -> models.AgentConfig:
|
||||
""" Get the agent configuration for the pool """
|
||||
|
||||
pool = self._req_model("GET", models.Pool, json_data={"name": pool_name})
|
||||
pool = self.get(pool_name)
|
||||
|
||||
if pool.config is None:
|
||||
raise Exception("Missing AgentConfig in response")
|
||||
@ -858,7 +875,7 @@ class Pool(Endpoint):
|
||||
return self._req_model(
|
||||
"DELETE",
|
||||
responses.BoolResult,
|
||||
json_data={"name": expanded_name, "now": now},
|
||||
data=requests.PoolStop(name=expanded_name, now=now),
|
||||
)
|
||||
|
||||
def get(self, name: str) -> models.Pool:
|
||||
@ -867,13 +884,17 @@ class Pool(Endpoint):
|
||||
"name", name, lambda x: False, lambda: [x.name for x in self.list()]
|
||||
)
|
||||
|
||||
return self._req_model("GET", models.Pool, json_data={"name": expanded_name})
|
||||
return self._req_model(
|
||||
"GET", models.Pool, data=requests.PoolSearch(name=expanded_name)
|
||||
)
|
||||
|
||||
def list(
|
||||
self, *, state: Optional[List[enums.PoolState]] = None
|
||||
) -> List[models.Pool]:
|
||||
self.logger.debug("list worker pools")
|
||||
return self._req_model_list("GET", models.Pool, json_data={"state": state})
|
||||
return self._req_model_list(
|
||||
"GET", models.Pool, data=requests.PoolSearch(state=state)
|
||||
)
|
||||
|
||||
|
||||
class Node(Endpoint):
|
||||
@ -890,7 +911,7 @@ class Node(Endpoint):
|
||||
)
|
||||
|
||||
return self._req_model(
|
||||
"GET", models.Node, json_data={"machine_id": machine_id_expanded}
|
||||
"GET", models.Node, data=requests.NodeGet(machine_id=machine_id_expanded)
|
||||
)
|
||||
|
||||
def halt(self, machine_id: UUID_EXPANSION) -> responses.BoolResult:
|
||||
@ -904,7 +925,7 @@ class Node(Endpoint):
|
||||
return self._req_model(
|
||||
"DELETE",
|
||||
responses.BoolResult,
|
||||
json_data={"machine_id": machine_id_expanded},
|
||||
data=requests.NodeGet(machine_id=machine_id_expanded),
|
||||
)
|
||||
|
||||
def reimage(self, machine_id: UUID_EXPANSION) -> responses.BoolResult:
|
||||
@ -916,7 +937,9 @@ class Node(Endpoint):
|
||||
)
|
||||
|
||||
return self._req_model(
|
||||
"PATCH", responses.BoolResult, json_data={"machine_id": machine_id_expanded}
|
||||
"PATCH",
|
||||
responses.BoolResult,
|
||||
data=requests.NodeGet(machine_id=machine_id_expanded),
|
||||
)
|
||||
|
||||
def list(
|
||||
@ -947,11 +970,9 @@ class Node(Endpoint):
|
||||
return self._req_model_list(
|
||||
"GET",
|
||||
models.Node,
|
||||
json_data={
|
||||
"state": state,
|
||||
"scaleset_id": scaleset_id_expanded,
|
||||
"pool_name": pool_name,
|
||||
},
|
||||
data=requests.NodeSearch(
|
||||
scaleset_id=scaleset_id_expanded, state=state, pool_name=pool_name
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ -1006,15 +1027,15 @@ class Scaleset(Endpoint):
|
||||
return self._req_model(
|
||||
"POST",
|
||||
models.Scaleset,
|
||||
json_data={
|
||||
"pool_name": pool_name,
|
||||
"vm_sku": vm_sku,
|
||||
"image": image,
|
||||
"region": region,
|
||||
"size": size,
|
||||
"spot_instances": spot_instances,
|
||||
"tags": tags,
|
||||
},
|
||||
data=requests.ScalesetCreate(
|
||||
pool_name=pool_name,
|
||||
vm_sku=vm_sku,
|
||||
image=image,
|
||||
region=region,
|
||||
size=size,
|
||||
spot_instances=spot_instances,
|
||||
tags=tags,
|
||||
),
|
||||
)
|
||||
|
||||
def shutdown(
|
||||
@ -1030,7 +1051,7 @@ class Scaleset(Endpoint):
|
||||
return self._req_model(
|
||||
"DELETE",
|
||||
responses.BoolResult,
|
||||
json_data={"scaleset_id": scaleset_id_expanded, "now": now},
|
||||
data=requests.ScalesetStop(scaleset_id=scaleset_id_expanded, now=now),
|
||||
)
|
||||
|
||||
def get(
|
||||
@ -1046,10 +1067,9 @@ class Scaleset(Endpoint):
|
||||
return self._req_model(
|
||||
"GET",
|
||||
models.Scaleset,
|
||||
json_data={
|
||||
"scaleset_id": scaleset_id_expanded,
|
||||
"include_auth": include_auth,
|
||||
},
|
||||
data=requests.ScalesetSearch(
|
||||
scaleset_id=scaleset_id_expanded, include_auth=include_auth
|
||||
),
|
||||
)
|
||||
|
||||
def update(
|
||||
@ -1065,7 +1085,7 @@ class Scaleset(Endpoint):
|
||||
return self._req_model(
|
||||
"PATCH",
|
||||
models.Scaleset,
|
||||
json_data={"scaleset_id": scaleset_id_expanded, "size": size},
|
||||
data=requests.ScalesetUpdate(scaleset_id=scaleset_id_expanded, size=size),
|
||||
)
|
||||
|
||||
def list(
|
||||
@ -1074,7 +1094,9 @@ class Scaleset(Endpoint):
|
||||
state: Optional[List[enums.ScalesetState]] = None,
|
||||
) -> List[models.Scaleset]:
|
||||
self.logger.debug("list scalesets")
|
||||
return self._req_model_list("GET", models.Scaleset, json_data={"state": state})
|
||||
return self._req_model_list(
|
||||
"GET", models.Scaleset, data=requests.ScalesetSearch(state=state)
|
||||
)
|
||||
|
||||
|
||||
class ScalesetProxy(Endpoint):
|
||||
@ -1105,18 +1127,18 @@ class ScalesetProxy(Endpoint):
|
||||
return self._req_model(
|
||||
"DELETE",
|
||||
responses.BoolResult,
|
||||
json_data={
|
||||
"scaleset_id": scaleset.scaleset_id,
|
||||
"machine_id": machine_id_expanded,
|
||||
"dst_port": dst_port,
|
||||
},
|
||||
data=requests.ProxyDelete(
|
||||
scaleset_id=scaleset.scaleset_id,
|
||||
machine_id=machine_id_expanded,
|
||||
dst_port=dst_port,
|
||||
),
|
||||
)
|
||||
|
||||
def reset(self, region: str) -> responses.BoolResult:
|
||||
""" Reset the proxy for an existing region """
|
||||
|
||||
return self._req_model(
|
||||
"PATCH", responses.BoolResult, json_data={"region": region}
|
||||
"PATCH", responses.BoolResult, data=requests.ProxyReset(region=region)
|
||||
)
|
||||
|
||||
def get(
|
||||
@ -1134,11 +1156,11 @@ class ScalesetProxy(Endpoint):
|
||||
proxy = self._req_model(
|
||||
"GET",
|
||||
responses.ProxyGetResult,
|
||||
json_data={
|
||||
"scaleset_id": scaleset.scaleset_id,
|
||||
"machine_id": machine_id_expanded,
|
||||
"dst_port": dst_port,
|
||||
},
|
||||
data=requests.ProxyGet(
|
||||
scaleset_id=scaleset.scaleset_id,
|
||||
machine_id=machine_id_expanded,
|
||||
dst_port=dst_port,
|
||||
),
|
||||
)
|
||||
return proxy
|
||||
|
||||
@ -1165,12 +1187,12 @@ class ScalesetProxy(Endpoint):
|
||||
return self._req_model(
|
||||
"POST",
|
||||
responses.ProxyGetResult,
|
||||
json_data={
|
||||
"scaleset_id": scaleset.scaleset_id,
|
||||
"machine_id": machine_id_expanded,
|
||||
"dst_port": dst_port,
|
||||
"duration": duration,
|
||||
},
|
||||
data=requests.ProxyCreate(
|
||||
scaleset_id=scaleset.scaleset_id,
|
||||
machine_id=machine_id_expanded,
|
||||
dst_port=dst_port,
|
||||
duration=duration,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user