Split integration tests into different steps (#1650)

Refactoring check-pr.py to extract the logic of downloading the binaries
refactoring integration-tets.py to split the logic of setup, launch, check_result and cleanup
This commit is contained in:
Cheick Keita
2022-02-22 14:33:00 -08:00
committed by GitHub
parent 5d8516bd70
commit 674444b7d7
3 changed files with 348 additions and 194 deletions

View File

@ -34,7 +34,8 @@ from onefuzz.backend import ContainerWrapper, wait
from onefuzz.cli import execute_api
from onefuzztypes.enums import OS, ContainerType, TaskState, VmState
from onefuzztypes.models import Job, Pool, Repro, Scaleset, Task
from onefuzztypes.primitives import Container, Directory, File, PoolName, Region
from onefuzztypes.primitives import (Container, Directory, File, PoolName,
Region)
from pydantic import BaseModel, Field
LINUX_POOL = "linux-test"
@ -57,6 +58,11 @@ class TemplateType(Enum):
radamsa = "radamsa"
class LaunchInfo(BaseModel):
test_id: UUID
jobs: List[UUID]
class Integration(BaseModel):
template: TemplateType
os: OS
@ -248,7 +254,7 @@ def retry(
raise Exception(f"failed '{description}'")
else:
logger.info(
f"waiting {wait_duration} seconds before retrying '{description}'"
f"waiting {wait_duration} seconds before retrying '{description}'",
)
time.sleep(wait_duration)
@ -257,7 +263,6 @@ class TestOnefuzz:
def __init__(self, onefuzz: Onefuzz, logger: logging.Logger, test_id: UUID) -> None:
self.of = onefuzz
self.logger = logger
self.pools: Dict[OS, Pool] = {}
self.test_id = test_id
self.project = f"test-{self.test_id}"
self.start_log_marker = f"integration-test-injection-error-start-{self.test_id}"
@ -279,14 +284,21 @@ class TestOnefuzz:
for entry in os_list:
name = PoolName(f"testpool-{entry.name}-{self.test_id}")
self.logger.info("creating pool: %s:%s", entry.name, name)
self.pools[entry] = self.of.pools.create(name, entry)
self.of.pools.create(name, entry)
self.logger.info("creating scaleset for pool: %s", name)
self.of.scalesets.create(name, pool_size, region=region)
def launch(
self, path: Directory, *, os_list: List[OS], targets: List[str], duration=int
) -> None:
) -> List[UUID]:
"""Launch all of the fuzzing templates"""
pools = {}
for pool in self.of.pools.list():
pools[pool.os] = pool
job_ids = []
for target, config in TARGETS.items():
if target not in targets:
continue
@ -294,6 +306,9 @@ class TestOnefuzz:
if config.os not in os_list:
continue
if config.os not in pools.keys():
raise Exception(f"No pool for target: {target} ,os: {config.os}")
self.logger.info("launching: %s", target)
setup = Directory(os.path.join(path, target)) if config.use_setup else None
@ -313,7 +328,7 @@ class TestOnefuzz:
self.project,
target,
BUILD,
self.pools[config.os].name,
pools[config.os].name,
target_exe=target_exe,
inputs=inputs,
setup_dir=setup,
@ -329,7 +344,7 @@ class TestOnefuzz:
self.project,
target,
BUILD,
self.pools[config.os].name,
pools[config.os].name,
target_harness=config.target_exe,
inputs=inputs,
setup_dir=setup,
@ -342,7 +357,7 @@ class TestOnefuzz:
self.project,
target,
BUILD,
self.pools[config.os].name,
pools[config.os].name,
inputs=inputs,
target_exe=target_exe,
duration=duration,
@ -354,7 +369,7 @@ class TestOnefuzz:
self.project,
target,
BUILD,
pool_name=self.pools[config.os].name,
pool_name=pools[config.os].name,
target_exe=target_exe,
inputs=inputs,
setup_dir=setup,
@ -368,7 +383,7 @@ class TestOnefuzz:
self.project,
target,
BUILD,
pool_name=self.pools[config.os].name,
pool_name=pools[config.os].name,
target_exe=target_exe,
inputs=inputs,
setup_dir=setup,
@ -385,6 +400,10 @@ class TestOnefuzz:
if not job:
raise Exception("missing job")
job_ids.append(job.job_id)
return job_ids
def check_task(
self, job: Job, task: Task, scalesets: List[Scaleset]
) -> TaskTestState:
@ -426,10 +445,17 @@ class TestOnefuzz:
return TaskTestState.not_running
def check_jobs(
self, poll: bool = False, stop_on_complete_check: bool = False
self,
poll: bool = False,
stop_on_complete_check: bool = False,
job_ids: List[UUID] = [],
) -> bool:
"""Check all of the integration jobs"""
jobs: Dict[UUID, Job] = {x.job_id: x for x in self.get_jobs()}
jobs: Dict[UUID, Job] = {
x.job_id: x
for x in self.get_jobs()
if (not job_ids) or (x.job_id in job_ids)
}
job_tasks: Dict[UUID, List[Task]] = {}
check_containers: Dict[UUID, Dict[Container, Tuple[ContainerWrapper, int]]] = {}
@ -576,12 +602,16 @@ class TestOnefuzz:
return (container.name, files.files[0])
return None
def launch_repro(self) -> Tuple[bool, Dict[UUID, Tuple[Job, Repro]]]:
def launch_repro(
self, job_ids: List[UUID] = []
) -> Tuple[bool, Dict[UUID, Tuple[Job, Repro]]]:
# launch repro for one report from all succeessful jobs
has_cdb = bool(which("cdb.exe"))
has_gdb = bool(which("gdb"))
jobs = self.get_jobs()
jobs = [
job for job in self.get_jobs() if (not job_ids) or (job.job_id in job_ids)
]
result = True
repros = {}
@ -646,6 +676,7 @@ class TestOnefuzz:
job.config.name,
repro.error,
)
self.success = False
self.of.repro.delete(repro.vm_id)
del repros[job.job_id]
elif repro.state == VmState.running:
@ -665,14 +696,17 @@ class TestOnefuzz:
self.logger.error(
"repro failed: %s - %s", job.config.name, result
)
self.success = False
except Exception as err:
clear()
self.logger.error("repro failed: %s - %s", job.config.name, err)
self.success = False
del repros[job.job_id]
elif repro.state not in [VmState.init, VmState.extensions_launch]:
self.logger.error(
"repro failed: %s - bad state: %s", job.config.name, repro.state
)
self.success = False
del repros[job.job_id]
repro_states: Dict[str, List[str]] = {}
@ -878,6 +912,7 @@ class Run(Command):
client_secret: Optional[str],
poll: bool = False,
stop_on_complete_check: bool = False,
job_ids: List[UUID] = [],
) -> None:
self.onefuzz.__setup__(
endpoint=endpoint,
@ -887,7 +922,7 @@ class Run(Command):
)
tester = TestOnefuzz(self.onefuzz, self.logger, test_id)
result = tester.check_jobs(
poll=poll, stop_on_complete_check=stop_on_complete_check
poll=poll, stop_on_complete_check=stop_on_complete_check, job_ids=job_ids
)
if not result:
raise Exception("jobs failed")
@ -900,6 +935,7 @@ class Run(Command):
client_id: Optional[str],
client_secret: Optional[str],
authority: Optional[str] = None,
job_ids: List[UUID] = [],
) -> None:
self.onefuzz.__setup__(
endpoint=endpoint,
@ -908,14 +944,13 @@ class Run(Command):
authority=authority,
)
tester = TestOnefuzz(self.onefuzz, self.logger, test_id)
launch_result, repros = tester.launch_repro()
launch_result, repros = tester.launch_repro(job_ids=job_ids)
result = tester.check_repro(repros)
if not (result and launch_result):
raise Exception("repros failed")
def launch(
def setup(
self,
samples: Directory,
*,
endpoint: Optional[str] = None,
authority: Optional[str] = None,
@ -924,10 +959,8 @@ class Run(Command):
pool_size: int = 10,
region: Optional[Region] = None,
os_list: List[OS] = [OS.linux, OS.windows],
targets: List[str] = list(TARGETS.keys()),
test_id: Optional[UUID] = None,
duration: int = 1,
) -> UUID:
) -> None:
if test_id is None:
test_id = uuid4()
self.logger.info("launching test_id: %s", test_id)
@ -944,8 +977,42 @@ class Run(Command):
tester = TestOnefuzz(self.onefuzz, self.logger, test_id)
tester.setup(region=region, pool_size=pool_size, os_list=os_list)
tester.launch(samples, os_list=os_list, targets=targets, duration=duration)
return test_id
def launch(
self,
samples: Directory,
*,
endpoint: Optional[str] = None,
authority: Optional[str] = None,
client_id: Optional[str] = None,
client_secret: Optional[str] = None,
os_list: List[OS] = [OS.linux, OS.windows],
targets: List[str] = list(TARGETS.keys()),
test_id: Optional[UUID] = None,
duration: int = 1,
) -> None:
if test_id is None:
test_id = uuid4()
self.logger.info("launching test_id: %s", test_id)
def try_setup(data: Any) -> None:
self.onefuzz.__setup__(
endpoint=endpoint,
client_id=client_id,
client_secret=client_secret,
authority=authority,
)
retry(try_setup, "trying to configure")
tester = TestOnefuzz(self.onefuzz, self.logger, test_id)
job_ids = tester.launch(
samples, os_list=os_list, targets=targets, duration=duration
)
launch_data = LaunchInfo(test_id=test_id, jobs=job_ids)
print(f"launch info: {launch_data.json()}")
def cleanup(
self,
@ -983,6 +1050,41 @@ class Run(Command):
tester = TestOnefuzz(self.onefuzz, self.logger, test_id=test_id)
tester.check_logs_for_errors()
def check_results(
self,
*,
endpoint: Optional[str] = None,
authority: Optional[str] = None,
client_id: Optional[str] = None,
client_secret: Optional[str] = None,
skip_repro: bool = False,
test_id: UUID,
job_ids: List[UUID] = [],
) -> None:
self.check_jobs(
test_id,
endpoint=endpoint,
authority=authority,
client_id=client_id,
client_secret=client_secret,
poll=True,
stop_on_complete_check=True,
job_ids=job_ids,
)
if skip_repro:
self.logger.warning("not testing crash repro")
else:
self.check_repros(
test_id,
endpoint=endpoint,
authority=authority,
client_id=client_id,
client_secret=client_secret,
job_ids=job_ids,
)
def test(
self,
samples: Directory,
@ -1003,47 +1105,31 @@ class Run(Command):
test_id = uuid4()
error: Optional[Exception] = None
try:
self.launch(
samples,
def try_setup(data: Any) -> None:
self.onefuzz.__setup__(
endpoint=endpoint,
authority=authority,
client_id=client_id,
client_secret=client_secret,
pool_size=pool_size,
region=region,
os_list=os_list,
targets=targets,
test_id=test_id,
duration=duration,
)
self.check_jobs(
test_id,
endpoint=endpoint,
authority=authority,
client_id=client_id,
client_secret=client_secret,
poll=True,
stop_on_complete_check=True,
)
retry(try_setup, "trying to configure")
tester = TestOnefuzz(self.onefuzz, self.logger, test_id)
tester.setup(region=region, pool_size=pool_size, os_list=os_list)
tester.launch(samples, os_list=os_list, targets=targets, duration=duration)
result = tester.check_jobs(poll=True, stop_on_complete_check=True)
if not result:
raise Exception("jobs failed")
if skip_repro:
self.logger.warning("not testing crash repro")
else:
self.check_repros(
test_id,
endpoint=endpoint,
authority=authority,
client_id=client_id,
client_secret=client_secret,
)
launch_result, repros = tester.launch_repro()
result = tester.check_repro(repros)
if not (result and launch_result):
raise Exception("repros failed")
self.check_logs(
test_id,
endpoint=endpoint,
client_id=client_id,
client_secret=client_secret,
authority=authority,
)
tester.check_logs_for_errors()
except Exception as e:
self.logger.error("testing failed: %s", repr(e))

View File

@ -6,146 +6,14 @@
import argparse
import os
import subprocess
import sys
import tempfile
import time
import uuid
from typing import Callable, List, Optional, Tuple, TypeVar
import requests
from github import Github
from typing import List, Optional
from cleanup_ad import delete_current_user_app_registrations
A = TypeVar("A")
def wait(func: Callable[[], Tuple[bool, str, A]], frequency: float = 1.0) -> A:
"""
Wait until the provided func returns True.
Provides user feedback via a spinner if stdout is a TTY.
"""
isatty = sys.stdout.isatty()
frames = ["-", "\\", "|", "/"]
waited = False
last_message = None
result = None
try:
while True:
result = func()
if result[0]:
break
message = result[1]
if isatty:
if last_message:
if last_message == message:
sys.stdout.write("\b" * (len(last_message) + 2))
else:
sys.stdout.write("\n")
sys.stdout.write("%s %s" % (frames[0], message))
sys.stdout.flush()
elif last_message != message:
print(message, flush=True)
last_message = message
waited = True
time.sleep(frequency)
frames.sort(key=frames[0].__eq__)
finally:
if waited and isatty:
print(flush=True)
return result[2]
class Downloader:
def __init__(self) -> None:
self.gh = Github(login_or_token=os.environ["GITHUB_ISSUE_TOKEN"])
def update_pr(self, repo_name: str, pr: int) -> None:
pr_obj = self.gh.get_repo(repo_name).get_pull(pr)
if pr_obj.mergeable_state == "behind":
print(f"pr:{pr} out of date. Updating")
pr_obj.update_branch() # type: ignore
time.sleep(5)
elif pr_obj.mergeable_state == "dirty":
raise Exception(f"merge confict errors on pr:{pr}")
def merge_pr(self, repo_name: str, pr: int) -> None:
pr_obj = self.gh.get_repo(repo_name).get_pull(pr)
if pr_obj.mergeable_state == "clean":
print(f"merging pr:{pr}")
pr_obj.merge(commit_message="", merge_method="squash")
else:
print(f"unable to merge pr:{pr}", pr_obj.mergeable_state)
def get_artifact(
self,
repo_name: str,
workflow: str,
branch: Optional[str],
pr: Optional[int],
name: str,
filename: str,
) -> None:
print(f"getting {name}")
if pr:
self.update_pr(repo_name, pr)
branch = self.gh.get_repo(repo_name).get_pull(pr).head.ref
if not branch:
raise Exception("missing branch")
zip_file_url = self.get_artifact_url(repo_name, workflow, branch, name)
(code, resp, _) = self.gh._Github__requester.requestBlob( # type: ignore
"GET", zip_file_url, {}
)
if code != 302:
raise Exception(f"unexpected response: {resp}")
with open(filename, "wb") as handle:
for chunk in requests.get(resp["location"], stream=True).iter_content(
chunk_size=1024 * 16
):
handle.write(chunk)
def get_artifact_url(
self, repo_name: str, workflow_name: str, branch: str, name: str
) -> str:
repo = self.gh.get_repo(repo_name)
workflow = repo.get_workflow(workflow_name)
runs = workflow.get_runs()
run = None
for x in runs:
if x.head_branch != branch:
continue
run = x
break
if not run:
raise Exception("invalid run")
print("using run from branch", run.head_branch)
def check() -> Tuple[bool, str, None]:
if run is None:
raise Exception("invalid run")
run.update()
return run.status == "completed", run.status, None
wait(check, frequency=10.0)
if run.conclusion != "success":
raise Exception(f"bad conclusion: {run.conclusion}")
response = requests.get(run.artifacts_url).json()
for artifact in response["artifacts"]:
if artifact["name"] == name:
return str(artifact["archive_download_url"])
raise Exception(f"no archive url for {branch} - {name}")
from .github_client import GithubClient
def venv_path(base: str, name: str) -> str:
@ -173,7 +41,7 @@ class Deployer:
repo: str,
unattended: bool,
):
self.downloader = Downloader()
self.downloader = GithubClient()
self.pr = pr
self.branch = branch
self.instance = instance

View File

@ -0,0 +1,200 @@
import argparse
import os
import sys
import time
from pathlib import Path
from typing import Callable, Optional, Tuple, TypeVar
import requests
from github import Github
A = TypeVar("A")
def wait(func: Callable[[], Tuple[bool, str, A]], frequency: float = 1.0) -> A:
"""
Wait until the provided func returns True.
Provides user feedback via a spinner if stdout is a TTY.
"""
isatty = sys.stdout.isatty()
frames = ["-", "\\", "|", "/"]
waited = False
last_message = None
result = None
try:
while True:
result = func()
if result[0]:
break
message = result[1]
if isatty:
if last_message:
if last_message == message:
sys.stdout.write("\b" * (len(last_message) + 2))
else:
sys.stdout.write("\n")
sys.stdout.write("%s %s" % (frames[0], message))
sys.stdout.flush()
elif last_message != message:
print(message, flush=True)
last_message = message
waited = True
time.sleep(frequency)
frames.sort(key=frames[0].__eq__)
finally:
if waited and isatty:
print(flush=True)
return result[2]
class GithubClient:
def __init__(self) -> None:
self.gh = Github(login_or_token=os.environ["GITHUB_ISSUE_TOKEN"])
def update_pr(self, repo_name: str, pr: int, update_if_behind: bool = True) -> None:
pr_obj = self.gh.get_repo(repo_name).get_pull(pr)
if (pr_obj.mergeable_state == "behind") and update_if_behind:
print(f"pr:{pr} out of date. Updating")
pr_obj.update_branch() # type: ignore
time.sleep(5)
elif pr_obj.mergeable_state == "dirty":
raise Exception(f"merge confict errors on pr:{pr}")
def merge_pr(self, repo_name: str, pr: int) -> None:
pr_obj = self.gh.get_repo(repo_name).get_pull(pr)
if pr_obj.mergeable_state == "clean":
print(f"merging pr:{pr}")
pr_obj.merge(commit_message="", merge_method="squash")
else:
print(f"unable to merge pr:{pr}", pr_obj.mergeable_state)
def get_artifact(
self,
repo_name: str,
workflow: str,
branch: Optional[str],
pr: Optional[int],
name: str,
file_path: str,
update_branch: bool = True,
) -> None:
print(f"getting {name}")
if pr:
self.update_pr(repo_name, pr, update_branch)
branch = self.gh.get_repo(repo_name).get_pull(pr).head.ref
if not branch:
raise Exception("missing branch")
zip_file_url = self.get_artifact_url(repo_name, workflow, branch, name)
(code, resp, _) = self.gh._Github__requester.requestBlob( # type: ignore
"GET", zip_file_url, {}
)
if code != 302:
raise Exception(f"unexpected response: {resp}")
with open(file_path, "wb") as handle:
print(f"writing {file_path}")
for chunk in requests.get(resp["location"], stream=True).iter_content(
chunk_size=1024 * 16
):
handle.write(chunk)
def get_artifact_url(
self, repo_name: str, workflow_name: str, branch: str, name: str
) -> str:
repo = self.gh.get_repo(repo_name)
workflow = repo.get_workflow(workflow_name)
runs = workflow.get_runs()
run = None
for x in runs:
if x.head_branch != branch:
continue
run = x
break
if not run:
raise Exception("invalid run")
print("using run from branch", run.head_branch)
def check() -> Tuple[bool, str, None]:
if run is None:
raise Exception("invalid run")
run.update()
return run.status == "completed", run.status, None
wait(check, frequency=10.0)
if run.conclusion != "success":
raise Exception(f"bad conclusion: {run.conclusion}")
response = requests.get(run.artifacts_url).json()
for artifact in response["artifacts"]:
if artifact["name"] == name:
return str(artifact["archive_download_url"])
raise Exception(f"no archive url for {branch} - {name}")
def download_artifacts(
downloader: GithubClient,
repo: str,
branch: Optional[str],
pr: Optional[int],
directory: str,
update_branch: bool = True,
) -> None:
release_filename = "release-artifacts.zip"
downloader.get_artifact(
repo,
"ci.yml",
branch,
pr,
"release-artifacts",
os.path.join(directory, release_filename),
)
test_filename = "integration-test-artifacts.zip"
downloader.get_artifact(
repo,
"ci.yml",
branch,
pr,
"integration-test-artifacts",
os.path.join(directory, test_filename),
)
def main() -> None:
parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group()
group.add_argument("--branch")
group.add_argument("--pr", type=int)
parser.add_argument("--repo", default="microsoft/onefuzz")
parser.add_argument("--destination", default=os.getcwd())
parser.add_argument("--skip_update", action="store_true")
args = parser.parse_args()
path = Path(args.destination)
path.mkdir(parents=True, exist_ok=True)
downloader = GithubClient()
download_artifacts(
downloader,
args.repo,
args.branch,
args.pr,
args.destination,
not args.skip_update,
)
if __name__ == "__main__":
main()