Mirror of https://github.com/microsoft/onefuzz.git (synced 2025-06-18 12:48:07 +00:00)
Split integration tests into different steps (#1650)
Refactors check-pr.py to extract the binary-download logic into a reusable GithubClient, and refactors integration-test.py to split the test flow into separate setup, launch, check_results, and cleanup steps.
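The sketch below (not part of this commit) shows how a driver such as CI could chain the new steps: it assumes the "launch info: {...}" line printed by the launch step is captured from stdout and that the parsed job IDs are then passed to check_results. parse_launch_info is a hypothetical helper; only the LaunchInfo model and the printed format come from the diff that follows.

    from typing import List
    from uuid import UUID

    from pydantic import BaseModel


    class LaunchInfo(BaseModel):
        # mirrors the model added in this commit
        test_id: UUID
        jobs: List[UUID]


    def parse_launch_info(stdout: str) -> LaunchInfo:
        # hypothetical helper: find the single "launch info: <json>" line printed by launch
        for line in stdout.splitlines():
            if line.startswith("launch info: "):
                return LaunchInfo.parse_raw(line[len("launch info: "):])
        raise ValueError("no launch info found in output")


    # assumed driver flow:
    #   info = parse_launch_info(captured_stdout)
    #   run.check_results(test_id=info.test_id, job_ids=info.jobs)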
@@ -34,7 +34,8 @@ from onefuzz.backend import ContainerWrapper, wait
 from onefuzz.cli import execute_api
 from onefuzztypes.enums import OS, ContainerType, TaskState, VmState
 from onefuzztypes.models import Job, Pool, Repro, Scaleset, Task
-from onefuzztypes.primitives import Container, Directory, File, PoolName, Region
+from onefuzztypes.primitives import (Container, Directory, File, PoolName,
+                                     Region)
 from pydantic import BaseModel, Field

 LINUX_POOL = "linux-test"
@@ -57,6 +58,11 @@ class TemplateType(Enum):
     radamsa = "radamsa"


+class LaunchInfo(BaseModel):
+    test_id: UUID
+    jobs: List[UUID]
+
+
 class Integration(BaseModel):
     template: TemplateType
     os: OS
@@ -248,7 +254,7 @@ def retry(
             raise Exception(f"failed '{description}'")
         else:
             logger.info(
-                f"waiting {wait_duration} seconds before retrying '{description}'"
+                f"waiting {wait_duration} seconds before retrying '{description}'",
             )
             time.sleep(wait_duration)

@@ -257,7 +263,6 @@ class TestOnefuzz:
     def __init__(self, onefuzz: Onefuzz, logger: logging.Logger, test_id: UUID) -> None:
         self.of = onefuzz
         self.logger = logger
-        self.pools: Dict[OS, Pool] = {}
         self.test_id = test_id
         self.project = f"test-{self.test_id}"
         self.start_log_marker = f"integration-test-injection-error-start-{self.test_id}"
@@ -279,14 +284,21 @@ class TestOnefuzz:
         for entry in os_list:
             name = PoolName(f"testpool-{entry.name}-{self.test_id}")
             self.logger.info("creating pool: %s:%s", entry.name, name)
-            self.pools[entry] = self.of.pools.create(name, entry)
+            self.of.pools.create(name, entry)
             self.logger.info("creating scaleset for pool: %s", name)
             self.of.scalesets.create(name, pool_size, region=region)

     def launch(
         self, path: Directory, *, os_list: List[OS], targets: List[str], duration=int
-    ) -> None:
+    ) -> List[UUID]:
         """Launch all of the fuzzing templates"""
+
+        pools = {}
+        for pool in self.of.pools.list():
+            pools[pool.os] = pool
+
+        job_ids = []
+
         for target, config in TARGETS.items():
             if target not in targets:
                 continue
@@ -294,6 +306,9 @@ class TestOnefuzz:
             if config.os not in os_list:
                 continue

+            if config.os not in pools.keys():
+                raise Exception(f"No pool for target: {target} ,os: {config.os}")
+
             self.logger.info("launching: %s", target)

             setup = Directory(os.path.join(path, target)) if config.use_setup else None
@@ -313,7 +328,7 @@ class TestOnefuzz:
                     self.project,
                     target,
                     BUILD,
-                    self.pools[config.os].name,
+                    pools[config.os].name,
                     target_exe=target_exe,
                     inputs=inputs,
                     setup_dir=setup,
@@ -329,7 +344,7 @@ class TestOnefuzz:
                     self.project,
                     target,
                     BUILD,
-                    self.pools[config.os].name,
+                    pools[config.os].name,
                     target_harness=config.target_exe,
                     inputs=inputs,
                     setup_dir=setup,
@@ -342,7 +357,7 @@ class TestOnefuzz:
                     self.project,
                     target,
                     BUILD,
-                    self.pools[config.os].name,
+                    pools[config.os].name,
                     inputs=inputs,
                     target_exe=target_exe,
                     duration=duration,
@@ -354,7 +369,7 @@ class TestOnefuzz:
                     self.project,
                     target,
                     BUILD,
-                    pool_name=self.pools[config.os].name,
+                    pool_name=pools[config.os].name,
                     target_exe=target_exe,
                     inputs=inputs,
                     setup_dir=setup,
@@ -368,7 +383,7 @@ class TestOnefuzz:
                     self.project,
                     target,
                     BUILD,
-                    pool_name=self.pools[config.os].name,
+                    pool_name=pools[config.os].name,
                     target_exe=target_exe,
                     inputs=inputs,
                     setup_dir=setup,
@@ -385,6 +400,10 @@ class TestOnefuzz:
             if not job:
                 raise Exception("missing job")

+            job_ids.append(job.job_id)
+
+        return job_ids
+
     def check_task(
         self, job: Job, task: Task, scalesets: List[Scaleset]
     ) -> TaskTestState:
@@ -426,10 +445,17 @@ class TestOnefuzz:
         return TaskTestState.not_running

     def check_jobs(
-        self, poll: bool = False, stop_on_complete_check: bool = False
+        self,
+        poll: bool = False,
+        stop_on_complete_check: bool = False,
+        job_ids: List[UUID] = [],
     ) -> bool:
         """Check all of the integration jobs"""
-        jobs: Dict[UUID, Job] = {x.job_id: x for x in self.get_jobs()}
+        jobs: Dict[UUID, Job] = {
+            x.job_id: x
+            for x in self.get_jobs()
+            if (not job_ids) or (x.job_id in job_ids)
+        }
         job_tasks: Dict[UUID, List[Task]] = {}
         check_containers: Dict[UUID, Dict[Container, Tuple[ContainerWrapper, int]]] = {}

@@ -576,12 +602,16 @@ class TestOnefuzz:
                 return (container.name, files.files[0])
         return None

-    def launch_repro(self) -> Tuple[bool, Dict[UUID, Tuple[Job, Repro]]]:
+    def launch_repro(
+        self, job_ids: List[UUID] = []
+    ) -> Tuple[bool, Dict[UUID, Tuple[Job, Repro]]]:
         # launch repro for one report from all succeessful jobs
         has_cdb = bool(which("cdb.exe"))
         has_gdb = bool(which("gdb"))

-        jobs = self.get_jobs()
+        jobs = [
+            job for job in self.get_jobs() if (not job_ids) or (job.job_id in job_ids)
+        ]

         result = True
         repros = {}
@@ -646,6 +676,7 @@ class TestOnefuzz:
                         job.config.name,
                         repro.error,
                     )
+                    self.success = False
                     self.of.repro.delete(repro.vm_id)
                     del repros[job.job_id]
                 elif repro.state == VmState.running:
@@ -665,14 +696,17 @@ class TestOnefuzz:
                             self.logger.error(
                                 "repro failed: %s - %s", job.config.name, result
                             )
+                            self.success = False
                     except Exception as err:
                         clear()
                         self.logger.error("repro failed: %s - %s", job.config.name, err)
+                        self.success = False
                         del repros[job.job_id]
                 elif repro.state not in [VmState.init, VmState.extensions_launch]:
                     self.logger.error(
                         "repro failed: %s - bad state: %s", job.config.name, repro.state
                     )
+                    self.success = False
                     del repros[job.job_id]

         repro_states: Dict[str, List[str]] = {}
@@ -878,6 +912,7 @@ class Run(Command):
         client_secret: Optional[str],
         poll: bool = False,
         stop_on_complete_check: bool = False,
+        job_ids: List[UUID] = [],
     ) -> None:
         self.onefuzz.__setup__(
             endpoint=endpoint,
@@ -887,7 +922,7 @@ class Run(Command):
         )
         tester = TestOnefuzz(self.onefuzz, self.logger, test_id)
         result = tester.check_jobs(
-            poll=poll, stop_on_complete_check=stop_on_complete_check
+            poll=poll, stop_on_complete_check=stop_on_complete_check, job_ids=job_ids
         )
         if not result:
             raise Exception("jobs failed")
@@ -900,6 +935,7 @@ class Run(Command):
         client_id: Optional[str],
         client_secret: Optional[str],
         authority: Optional[str] = None,
+        job_ids: List[UUID] = [],
     ) -> None:
         self.onefuzz.__setup__(
             endpoint=endpoint,
@@ -908,14 +944,13 @@ class Run(Command):
             authority=authority,
         )
         tester = TestOnefuzz(self.onefuzz, self.logger, test_id)
-        launch_result, repros = tester.launch_repro()
+        launch_result, repros = tester.launch_repro(job_ids=job_ids)
         result = tester.check_repro(repros)
         if not (result and launch_result):
             raise Exception("repros failed")

-    def launch(
+    def setup(
         self,
-        samples: Directory,
         *,
         endpoint: Optional[str] = None,
         authority: Optional[str] = None,
@@ -924,10 +959,8 @@ class Run(Command):
         pool_size: int = 10,
         region: Optional[Region] = None,
         os_list: List[OS] = [OS.linux, OS.windows],
-        targets: List[str] = list(TARGETS.keys()),
         test_id: Optional[UUID] = None,
-        duration: int = 1,
-    ) -> None:
+    ) -> UUID:
         if test_id is None:
             test_id = uuid4()
         self.logger.info("launching test_id: %s", test_id)
@@ -944,8 +977,42 @@ class Run(Command):

         tester = TestOnefuzz(self.onefuzz, self.logger, test_id)
         tester.setup(region=region, pool_size=pool_size, os_list=os_list)
-        tester.launch(samples, os_list=os_list, targets=targets, duration=duration)
+        return test_id
+
+    def launch(
+        self,
+        samples: Directory,
+        *,
+        endpoint: Optional[str] = None,
+        authority: Optional[str] = None,
+        client_id: Optional[str] = None,
+        client_secret: Optional[str] = None,
+        os_list: List[OS] = [OS.linux, OS.windows],
+        targets: List[str] = list(TARGETS.keys()),
+        test_id: Optional[UUID] = None,
+        duration: int = 1,
+    ) -> None:
+        if test_id is None:
+            test_id = uuid4()
+        self.logger.info("launching test_id: %s", test_id)
+
+        def try_setup(data: Any) -> None:
+            self.onefuzz.__setup__(
+                endpoint=endpoint,
+                client_id=client_id,
+                client_secret=client_secret,
+                authority=authority,
+            )
+
+        retry(try_setup, "trying to configure")
+
+        tester = TestOnefuzz(self.onefuzz, self.logger, test_id)
+
+        job_ids = tester.launch(
+            samples, os_list=os_list, targets=targets, duration=duration
+        )
+        launch_data = LaunchInfo(test_id=test_id, jobs=job_ids)
+
+        print(f"launch info: {launch_data.json()}")

     def cleanup(
         self,
@@ -983,6 +1050,41 @@ class Run(Command):
         tester = TestOnefuzz(self.onefuzz, self.logger, test_id=test_id)
         tester.check_logs_for_errors()

+    def check_results(
+        self,
+        *,
+        endpoint: Optional[str] = None,
+        authority: Optional[str] = None,
+        client_id: Optional[str] = None,
+        client_secret: Optional[str] = None,
+        skip_repro: bool = False,
+        test_id: UUID,
+        job_ids: List[UUID] = [],
+    ) -> None:
+
+        self.check_jobs(
+            test_id,
+            endpoint=endpoint,
+            authority=authority,
+            client_id=client_id,
+            client_secret=client_secret,
+            poll=True,
+            stop_on_complete_check=True,
+            job_ids=job_ids,
+        )
+
+        if skip_repro:
+            self.logger.warning("not testing crash repro")
+        else:
+            self.check_repros(
+                test_id,
+                endpoint=endpoint,
+                authority=authority,
+                client_id=client_id,
+                client_secret=client_secret,
+                job_ids=job_ids,
+            )
+
     def test(
         self,
         samples: Directory,
@@ -1003,47 +1105,31 @@ class Run(Command):
         test_id = uuid4()
         error: Optional[Exception] = None
         try:
-            self.launch(
-                samples,
-                endpoint=endpoint,
-                authority=authority,
-                client_id=client_id,
-                client_secret=client_secret,
-                pool_size=pool_size,
-                region=region,
-                os_list=os_list,
-                targets=targets,
-                test_id=test_id,
-                duration=duration,
-            )
-            self.check_jobs(
-                test_id,
-                endpoint=endpoint,
-                authority=authority,
-                client_id=client_id,
-                client_secret=client_secret,
-                poll=True,
-                stop_on_complete_check=True,
-            )

+            def try_setup(data: Any) -> None:
+                self.onefuzz.__setup__(
+                    endpoint=endpoint,
+                    client_id=client_id,
+                    client_secret=client_secret,
+                    authority=authority,
+                )
+
+            retry(try_setup, "trying to configure")
+            tester = TestOnefuzz(self.onefuzz, self.logger, test_id)
+            tester.setup(region=region, pool_size=pool_size, os_list=os_list)
+            tester.launch(samples, os_list=os_list, targets=targets, duration=duration)
+            result = tester.check_jobs(poll=True, stop_on_complete_check=True)
+            if not result:
+                raise Exception("jobs failed")
             if skip_repro:
                 self.logger.warning("not testing crash repro")
             else:
-                self.check_repros(
-                    test_id,
-                    endpoint=endpoint,
-                    authority=authority,
-                    client_id=client_id,
-                    client_secret=client_secret,
-                )
+                launch_result, repros = tester.launch_repro()
+                result = tester.check_repro(repros)
+                if not (result and launch_result):
+                    raise Exception("repros failed")

-            self.check_logs(
-                test_id,
-                endpoint=endpoint,
-                client_id=client_id,
-                client_secret=client_secret,
-                authority=authority,
-            )
+            tester.check_logs_for_errors()

         except Exception as e:
             self.logger.error("testing failed: %s", repr(e))
|
@ -6,146 +6,14 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
|
||||||
import tempfile
|
import tempfile
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Callable, List, Optional, Tuple, TypeVar
|
from typing import List, Optional
|
||||||
|
|
||||||
import requests
|
|
||||||
from github import Github
|
|
||||||
|
|
||||||
from cleanup_ad import delete_current_user_app_registrations
|
from cleanup_ad import delete_current_user_app_registrations
|
||||||
|
|
||||||
A = TypeVar("A")
|
from .github_client import GithubClient
|
||||||
|
|
||||||
|
|
||||||
def wait(func: Callable[[], Tuple[bool, str, A]], frequency: float = 1.0) -> A:
|
|
||||||
"""
|
|
||||||
Wait until the provided func returns True.
|
|
||||||
|
|
||||||
Provides user feedback via a spinner if stdout is a TTY.
|
|
||||||
"""
|
|
||||||
|
|
||||||
isatty = sys.stdout.isatty()
|
|
||||||
frames = ["-", "\\", "|", "/"]
|
|
||||||
waited = False
|
|
||||||
last_message = None
|
|
||||||
result = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
result = func()
|
|
||||||
if result[0]:
|
|
||||||
break
|
|
||||||
message = result[1]
|
|
||||||
|
|
||||||
if isatty:
|
|
||||||
if last_message:
|
|
||||||
if last_message == message:
|
|
||||||
sys.stdout.write("\b" * (len(last_message) + 2))
|
|
||||||
else:
|
|
||||||
sys.stdout.write("\n")
|
|
||||||
sys.stdout.write("%s %s" % (frames[0], message))
|
|
||||||
sys.stdout.flush()
|
|
||||||
elif last_message != message:
|
|
||||||
print(message, flush=True)
|
|
||||||
|
|
||||||
last_message = message
|
|
||||||
waited = True
|
|
||||||
time.sleep(frequency)
|
|
||||||
frames.sort(key=frames[0].__eq__)
|
|
||||||
finally:
|
|
||||||
if waited and isatty:
|
|
||||||
print(flush=True)
|
|
||||||
|
|
||||||
return result[2]
|
|
||||||
|
|
||||||
|
|
||||||
class Downloader:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.gh = Github(login_or_token=os.environ["GITHUB_ISSUE_TOKEN"])
|
|
||||||
|
|
||||||
def update_pr(self, repo_name: str, pr: int) -> None:
|
|
||||||
pr_obj = self.gh.get_repo(repo_name).get_pull(pr)
|
|
||||||
if pr_obj.mergeable_state == "behind":
|
|
||||||
print(f"pr:{pr} out of date. Updating")
|
|
||||||
pr_obj.update_branch() # type: ignore
|
|
||||||
time.sleep(5)
|
|
||||||
elif pr_obj.mergeable_state == "dirty":
|
|
||||||
raise Exception(f"merge confict errors on pr:{pr}")
|
|
||||||
|
|
||||||
def merge_pr(self, repo_name: str, pr: int) -> None:
|
|
||||||
pr_obj = self.gh.get_repo(repo_name).get_pull(pr)
|
|
||||||
if pr_obj.mergeable_state == "clean":
|
|
||||||
print(f"merging pr:{pr}")
|
|
||||||
pr_obj.merge(commit_message="", merge_method="squash")
|
|
||||||
else:
|
|
||||||
print(f"unable to merge pr:{pr}", pr_obj.mergeable_state)
|
|
||||||
|
|
||||||
def get_artifact(
|
|
||||||
self,
|
|
||||||
repo_name: str,
|
|
||||||
workflow: str,
|
|
||||||
branch: Optional[str],
|
|
||||||
pr: Optional[int],
|
|
||||||
name: str,
|
|
||||||
filename: str,
|
|
||||||
) -> None:
|
|
||||||
print(f"getting {name}")
|
|
||||||
|
|
||||||
if pr:
|
|
||||||
self.update_pr(repo_name, pr)
|
|
||||||
branch = self.gh.get_repo(repo_name).get_pull(pr).head.ref
|
|
||||||
if not branch:
|
|
||||||
raise Exception("missing branch")
|
|
||||||
|
|
||||||
zip_file_url = self.get_artifact_url(repo_name, workflow, branch, name)
|
|
||||||
|
|
||||||
(code, resp, _) = self.gh._Github__requester.requestBlob( # type: ignore
|
|
||||||
"GET", zip_file_url, {}
|
|
||||||
)
|
|
||||||
if code != 302:
|
|
||||||
raise Exception(f"unexpected response: {resp}")
|
|
||||||
|
|
||||||
with open(filename, "wb") as handle:
|
|
||||||
for chunk in requests.get(resp["location"], stream=True).iter_content(
|
|
||||||
chunk_size=1024 * 16
|
|
||||||
):
|
|
||||||
handle.write(chunk)
|
|
||||||
|
|
||||||
def get_artifact_url(
|
|
||||||
self, repo_name: str, workflow_name: str, branch: str, name: str
|
|
||||||
) -> str:
|
|
||||||
repo = self.gh.get_repo(repo_name)
|
|
||||||
workflow = repo.get_workflow(workflow_name)
|
|
||||||
runs = workflow.get_runs()
|
|
||||||
run = None
|
|
||||||
for x in runs:
|
|
||||||
if x.head_branch != branch:
|
|
||||||
continue
|
|
||||||
run = x
|
|
||||||
break
|
|
||||||
if not run:
|
|
||||||
raise Exception("invalid run")
|
|
||||||
|
|
||||||
print("using run from branch", run.head_branch)
|
|
||||||
|
|
||||||
def check() -> Tuple[bool, str, None]:
|
|
||||||
if run is None:
|
|
||||||
raise Exception("invalid run")
|
|
||||||
run.update()
|
|
||||||
return run.status == "completed", run.status, None
|
|
||||||
|
|
||||||
wait(check, frequency=10.0)
|
|
||||||
if run.conclusion != "success":
|
|
||||||
raise Exception(f"bad conclusion: {run.conclusion}")
|
|
||||||
|
|
||||||
response = requests.get(run.artifacts_url).json()
|
|
||||||
for artifact in response["artifacts"]:
|
|
||||||
if artifact["name"] == name:
|
|
||||||
return str(artifact["archive_download_url"])
|
|
||||||
raise Exception(f"no archive url for {branch} - {name}")
|
|
||||||
|
|
||||||
|
|
||||||
def venv_path(base: str, name: str) -> str:
|
def venv_path(base: str, name: str) -> str:
|
||||||
@@ -173,7 +41,7 @@ class Deployer:
         repo: str,
         unattended: bool,
     ):
-        self.downloader = Downloader()
+        self.downloader = GithubClient()
         self.pr = pr
         self.branch = branch
         self.instance = instance
src/utils/check-pr/github_client.py (new file, 200 lines)
@@ -0,0 +1,200 @@
+import argparse
+import os
+import sys
+import time
+from pathlib import Path
+from typing import Callable, Optional, Tuple, TypeVar
+
+import requests
+from github import Github
+
+A = TypeVar("A")
+
+
+def wait(func: Callable[[], Tuple[bool, str, A]], frequency: float = 1.0) -> A:
+    """
+    Wait until the provided func returns True.
+
+    Provides user feedback via a spinner if stdout is a TTY.
+    """
+
+    isatty = sys.stdout.isatty()
+    frames = ["-", "\\", "|", "/"]
+    waited = False
+    last_message = None
+    result = None
+
+    try:
+        while True:
+            result = func()
+            if result[0]:
+                break
+            message = result[1]
+
+            if isatty:
+                if last_message:
+                    if last_message == message:
+                        sys.stdout.write("\b" * (len(last_message) + 2))
+                    else:
+                        sys.stdout.write("\n")
+                sys.stdout.write("%s %s" % (frames[0], message))
+                sys.stdout.flush()
+            elif last_message != message:
+                print(message, flush=True)
+
+            last_message = message
+            waited = True
+            time.sleep(frequency)
+            frames.sort(key=frames[0].__eq__)
+    finally:
+        if waited and isatty:
+            print(flush=True)
+
+    return result[2]
+
+
+class GithubClient:
+    def __init__(self) -> None:
+        self.gh = Github(login_or_token=os.environ["GITHUB_ISSUE_TOKEN"])
+
+    def update_pr(self, repo_name: str, pr: int, update_if_behind: bool = True) -> None:
+        pr_obj = self.gh.get_repo(repo_name).get_pull(pr)
+        if (pr_obj.mergeable_state == "behind") and update_if_behind:
+            print(f"pr:{pr} out of date. Updating")
+            pr_obj.update_branch()  # type: ignore
+            time.sleep(5)
+        elif pr_obj.mergeable_state == "dirty":
+            raise Exception(f"merge confict errors on pr:{pr}")
+
+    def merge_pr(self, repo_name: str, pr: int) -> None:
+        pr_obj = self.gh.get_repo(repo_name).get_pull(pr)
+        if pr_obj.mergeable_state == "clean":
+            print(f"merging pr:{pr}")
+            pr_obj.merge(commit_message="", merge_method="squash")
+        else:
+            print(f"unable to merge pr:{pr}", pr_obj.mergeable_state)
+
+    def get_artifact(
+        self,
+        repo_name: str,
+        workflow: str,
+        branch: Optional[str],
+        pr: Optional[int],
+        name: str,
+        file_path: str,
+        update_branch: bool = True,
+    ) -> None:
+        print(f"getting {name}")
+
+        if pr:
+            self.update_pr(repo_name, pr, update_branch)
+            branch = self.gh.get_repo(repo_name).get_pull(pr).head.ref
+        if not branch:
+            raise Exception("missing branch")
+
+        zip_file_url = self.get_artifact_url(repo_name, workflow, branch, name)
+
+        (code, resp, _) = self.gh._Github__requester.requestBlob(  # type: ignore
+            "GET", zip_file_url, {}
+        )
+        if code != 302:
+            raise Exception(f"unexpected response: {resp}")
+
+        with open(file_path, "wb") as handle:
+            print(f"writing {file_path}")
+            for chunk in requests.get(resp["location"], stream=True).iter_content(
+                chunk_size=1024 * 16
+            ):
+                handle.write(chunk)
+
+    def get_artifact_url(
+        self, repo_name: str, workflow_name: str, branch: str, name: str
+    ) -> str:
+        repo = self.gh.get_repo(repo_name)
+        workflow = repo.get_workflow(workflow_name)
+        runs = workflow.get_runs()
+        run = None
+        for x in runs:
+            if x.head_branch != branch:
+                continue
+            run = x
+            break
+        if not run:
+            raise Exception("invalid run")
+
+        print("using run from branch", run.head_branch)
+
+        def check() -> Tuple[bool, str, None]:
+            if run is None:
+                raise Exception("invalid run")
+            run.update()
+            return run.status == "completed", run.status, None
+
+        wait(check, frequency=10.0)
+        if run.conclusion != "success":
+            raise Exception(f"bad conclusion: {run.conclusion}")
+
+        response = requests.get(run.artifacts_url).json()
+        for artifact in response["artifacts"]:
+            if artifact["name"] == name:
+                return str(artifact["archive_download_url"])
+        raise Exception(f"no archive url for {branch} - {name}")
+
+
+def download_artifacts(
+    downloader: GithubClient,
+    repo: str,
+    branch: Optional[str],
+    pr: Optional[int],
+    directory: str,
+    update_branch: bool = True,
+) -> None:
+    release_filename = "release-artifacts.zip"
+
+    downloader.get_artifact(
+        repo,
+        "ci.yml",
+        branch,
+        pr,
+        "release-artifacts",
+        os.path.join(directory, release_filename),
+    )
+
+    test_filename = "integration-test-artifacts.zip"
+    downloader.get_artifact(
+        repo,
+        "ci.yml",
+        branch,
+        pr,
+        "integration-test-artifacts",
+        os.path.join(directory, test_filename),
+    )
+
+
+def main() -> None:
+
+    parser = argparse.ArgumentParser()
+    group = parser.add_mutually_exclusive_group()
+    group.add_argument("--branch")
+    group.add_argument("--pr", type=int)
+    parser.add_argument("--repo", default="microsoft/onefuzz")
+    parser.add_argument("--destination", default=os.getcwd())
+    parser.add_argument("--skip_update", action="store_true")
+
+    args = parser.parse_args()
+    path = Path(args.destination)
+    path.mkdir(parents=True, exist_ok=True)
+
+    downloader = GithubClient()
+    download_artifacts(
+        downloader,
+        args.repo,
+        args.branch,
+        args.pr,
+        args.destination,
+        not args.skip_update,
+    )
+
+
+if __name__ == "__main__":
+    main()