mirror of https://github.com/microsoft/onefuzz.git (synced 2025-06-15 19:38:11 +00:00)
@ -29,13 +29,6 @@ echo "layout python3" >> .envrc
direnv allow
pip install -e .

echo "Install api-service"
cd /workspaces/onefuzz/src/api-service
echo "layout python3" >> .envrc
direnv allow
pip install -r requirements-dev.txt
cd __app__
pip install -r requirements.txt

cd /workspaces/onefuzz/src/utils
chmod u+x lint.sh
.github/codeql/codeql-config.yml (vendored, 1 line changed)
@ -3,4 +3,3 @@ paths:
- src/agent
- src/pytypes
- src/deployment
- src/api-service/__app__
@ -1,7 +0,0 @@
[flake8]
# Recommend matching the black line length (default 88),
# rather than using the flake8 default of 79:
max-line-length = 88
extend-ignore =
    # See https://github.com/PyCQA/pycodestyle/issues/373
    E203,
src/api-service/.gitignore (vendored, 49 lines changed)
@ -1,49 +0,0 @@
bin
obj
csx
.vs
edge
Publish

*.user
*.suo
*.cscfg
*.Cache
project.lock.json

/packages
/TestResults

/tools/NuGet.exe
/App_Data
/secrets
/data
.secrets
appsettings.json
local.settings.json

node_modules
dist

# Local python packages
.python_packages/

# Python Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
.dockerenv
host_secrets.json
local.settings.json
.mypy_cache/

__app__/onefuzzlib/build.id
@ -1,9 +0,0 @@
If you are doing development on the API service, you can build and deploy directly to your own instance using the azure-functions-core-tools.

From the api-service directory, do the following:

func azure functionapp publish <instance>

While Azure Functions will restart your instance with the new code, it may take a while. It may be helpful to restart your instance after pushing by doing the following:

az functionapp restart -g <group> -n <instance>
@ -1 +0,0 @@
local.settings.json
src/api-service/__app__/.gitignore (vendored, 4 lines changed)
@ -1,4 +0,0 @@
.direnv
.python_packages
__pycache__
.venv
@ -1,17 +0,0 @@
import os

import certifi.core


def override_where() -> str:
    """overrides certifi.core.where to return actual location of cacert.pem"""
    # see:
    # https://github.com/Azure/azure-functions-durable-python/issues/194#issuecomment-710670377
    # change this to match the location of cacert.pem
    return os.path.abspath(
        "cacert.pem"
    )  # or to whatever location you know contains the copy of cacert.pem


os.environ["REQUESTS_CA_BUNDLE"] = override_where()
certifi.core.where = override_where
@ -1,53 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import azure.functions as func
from onefuzztypes.enums import ErrorCode, TaskState
from onefuzztypes.models import Error
from onefuzztypes.requests import CanScheduleRequest
from onefuzztypes.responses import CanSchedule

from ..onefuzzlib.endpoint_authorization import call_if_agent
from ..onefuzzlib.request import not_ok, ok, parse_request
from ..onefuzzlib.tasks.main import Task
from ..onefuzzlib.workers.nodes import Node


def post(req: func.HttpRequest) -> func.HttpResponse:
    request = parse_request(CanScheduleRequest, req)
    if isinstance(request, Error):
        return not_ok(request, context="CanScheduleRequest")

    node = Node.get_by_machine_id(request.machine_id)
    if not node:
        return not_ok(
            Error(code=ErrorCode.UNABLE_TO_FIND, errors=["unable to find node"]),
            context=request.machine_id,
        )

    allowed = True
    work_stopped = False

    if not node.can_process_new_work():
        allowed = False

    task = Task.get_by_task_id(request.task_id)

    work_stopped = isinstance(task, Error) or task.state in TaskState.shutting_down()
    if work_stopped:
        allowed = False

    if allowed:
        allowed = not isinstance(node.acquire_scale_in_protection(), Error)

    return ok(CanSchedule(allowed=allowed, work_stopped=work_stopped))


def main(req: func.HttpRequest) -> func.HttpResponse:
    methods = {"POST": post}
    method = methods[req.method]
    result = call_if_agent(req, method)

    return result
@ -1,20 +0,0 @@
{
  "scriptFile": "__init__.py",
  "bindings": [
    {
      "authLevel": "anonymous",
      "type": "httpTrigger",
      "direction": "in",
      "name": "req",
      "methods": [
        "post"
      ],
      "route": "agents/can_schedule"
    },
    {
      "type": "http",
      "direction": "out",
      "name": "$return"
    }
  ]
}
@ -1,50 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import azure.functions as func
from onefuzztypes.models import Error, NodeCommandEnvelope
from onefuzztypes.requests import NodeCommandDelete, NodeCommandGet
from onefuzztypes.responses import BoolResult, PendingNodeCommand

from ..onefuzzlib.endpoint_authorization import call_if_agent
from ..onefuzzlib.request import not_ok, ok, parse_request
from ..onefuzzlib.workers.nodes import NodeMessage


def get(req: func.HttpRequest) -> func.HttpResponse:
    request = parse_request(NodeCommandGet, req)

    if isinstance(request, Error):
        return not_ok(request, context="NodeCommandGet")

    messages = NodeMessage.get_messages(request.machine_id, num_results=1)

    if messages:
        command = messages[0].message
        message_id = messages[0].message_id
        envelope = NodeCommandEnvelope(command=command, message_id=message_id)

        return ok(PendingNodeCommand(envelope=envelope))
    else:
        return ok(PendingNodeCommand(envelope=None))


def delete(req: func.HttpRequest) -> func.HttpResponse:
    request = parse_request(NodeCommandDelete, req)
    if isinstance(request, Error):
        return not_ok(request, context="NodeCommandDelete")

    message = NodeMessage.get(request.machine_id, request.message_id)
    if message:
        message.delete()
    return ok(BoolResult(result=True))


def main(req: func.HttpRequest) -> func.HttpResponse:
    methods = {"DELETE": delete, "GET": get}
    method = methods[req.method]
    result = call_if_agent(req, method)

    return result
@ -1,22 +0,0 @@
{
  "scriptFile": "__init__.py",
  "bindings": [
    {
      "authLevel": "anonymous",
      "type": "httpTrigger",
      "direction": "in",
      "name": "req",
      "methods": [
        "get",
        "post",
        "delete"
      ],
      "route": "agents/commands"
    },
    {
      "type": "http",
      "direction": "out",
      "name": "$return"
    }
  ]
}
@ -1,79 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import logging

import azure.functions as func
from onefuzztypes.models import (
    Error,
    NodeEvent,
    NodeEventEnvelope,
    NodeStateUpdate,
    Result,
    WorkerEvent,
)
from onefuzztypes.responses import BoolResult

from ..onefuzzlib.agent_events import on_state_update, on_worker_event
from ..onefuzzlib.endpoint_authorization import call_if_agent
from ..onefuzzlib.request import not_ok, ok, parse_request


def process(envelope: NodeEventEnvelope) -> Result[None]:
    logging.info(
        "node event: machine_id: %s event: %s",
        envelope.machine_id,
        envelope.event.json(exclude_none=True),
    )

    if isinstance(envelope.event, NodeStateUpdate):
        return on_state_update(envelope.machine_id, envelope.event)

    if isinstance(envelope.event, WorkerEvent):
        return on_worker_event(envelope.machine_id, envelope.event)

    if isinstance(envelope.event, NodeEvent):
        if envelope.event.state_update:
            result = on_state_update(envelope.machine_id, envelope.event.state_update)
            if result is not None:
                return result

        if envelope.event.worker_event:
            result = on_worker_event(envelope.machine_id, envelope.event.worker_event)
            if result is not None:
                return result

        return None

    raise NotImplementedError("invalid node event: %s" % envelope)


def post(req: func.HttpRequest) -> func.HttpResponse:
    envelope = parse_request(NodeEventEnvelope, req)
    if isinstance(envelope, Error):
        return not_ok(envelope, context="node event")

    logging.info(
        "node event: machine_id: %s event: %s",
        envelope.machine_id,
        envelope.event.json(exclude_none=True),
    )

    result = process(envelope)
    if isinstance(result, Error):
        logging.error(
            "unable to process agent event. envelope:%s error:%s", envelope, result
        )
        return not_ok(result, context="node event")

    return ok(BoolResult(result=True))


def main(req: func.HttpRequest) -> func.HttpResponse:
    methods = {"POST": post}
    method = methods[req.method]
    result = call_if_agent(req, method)

    return result
@ -1,20 +0,0 @@
{
  "scriptFile": "__init__.py",
  "bindings": [
    {
      "authLevel": "anonymous",
      "type": "httpTrigger",
      "direction": "in",
      "name": "req",
      "methods": [
        "post"
      ],
      "route": "agents/events"
    },
    {
      "type": "http",
      "direction": "out",
      "name": "$return"
    }
  ]
}
@ -1,119 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import datetime
import logging
from uuid import UUID

import azure.functions as func
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error
from onefuzztypes.requests import AgentRegistrationGet, AgentRegistrationPost
from onefuzztypes.responses import AgentRegistration

from ..onefuzzlib.azure.creds import get_agent_instance_url
from ..onefuzzlib.azure.queue import get_queue_sas
from ..onefuzzlib.azure.storage import StorageType
from ..onefuzzlib.endpoint_authorization import call_if_agent
from ..onefuzzlib.request import not_ok, ok, parse_uri
from ..onefuzzlib.workers.nodes import Node
from ..onefuzzlib.workers.pools import Pool


def create_registration_response(machine_id: UUID, pool: Pool) -> func.HttpResponse:
    base_address = get_agent_instance_url()
    events_url = "%s/api/agents/events" % base_address
    commands_url = "%s/api/agents/commands" % base_address
    work_queue = get_queue_sas(
        pool.get_pool_queue(),
        StorageType.corpus,
        read=True,
        update=True,
        process=True,
        duration=datetime.timedelta(hours=24),
    )
    return ok(
        AgentRegistration(
            events_url=events_url,
            commands_url=commands_url,
            work_queue=work_queue,
        )
    )


def get(req: func.HttpRequest) -> func.HttpResponse:
    get_registration = parse_uri(AgentRegistrationGet, req)

    if isinstance(get_registration, Error):
        return not_ok(get_registration, context="agent registration")

    agent_node = Node.get_by_machine_id(get_registration.machine_id)

    if agent_node is None:
        return not_ok(
            Error(
                code=ErrorCode.INVALID_REQUEST,
                errors=[
                    "unable to find a registration associated with machine_id '%s'"
                    % get_registration.machine_id
                ],
            ),
            context="agent registration",
        )
    else:
        pool = Pool.get_by_name(agent_node.pool_name)
        if isinstance(pool, Error):
            return not_ok(
                Error(
                    code=ErrorCode.INVALID_REQUEST,
                    errors=[
                        "unable to find a pool associated with the provided machine_id"
                    ],
                ),
                context="agent registration",
            )

    return create_registration_response(agent_node.machine_id, pool)


def post(req: func.HttpRequest) -> func.HttpResponse:
    registration_request = parse_uri(AgentRegistrationPost, req)
    if isinstance(registration_request, Error):
        return not_ok(registration_request, context="agent registration")
    logging.info(
        "registration request: %s", registration_request.json(exclude_none=True)
    )

    pool = Pool.get_by_name(registration_request.pool_name)
    if isinstance(pool, Error):
        return not_ok(
            Error(
                code=ErrorCode.INVALID_REQUEST,
                errors=["unable to find pool '%s'" % registration_request.pool_name],
            ),
            context="agent registration",
        )

    node = Node.get_by_machine_id(registration_request.machine_id)
    if node:
        node.delete()

    node = Node.create(
        pool_id=pool.pool_id,
        pool_name=pool.name,
        machine_id=registration_request.machine_id,
        scaleset_id=registration_request.scaleset_id,
        version=registration_request.version,
    )

    return create_registration_response(node.machine_id, pool)


def main(req: func.HttpRequest) -> func.HttpResponse:
    methods = {"POST": post, "GET": get}
    method = methods[req.method]
    result = call_if_agent(req, method)

    return result
@ -1,22 +0,0 @@
{
  "scriptFile": "__init__.py",
  "bindings": [
    {
      "authLevel": "anonymous",
      "type": "httpTrigger",
      "direction": "in",
      "name": "req",
      "methods": [
        "get",
        "post",
        "delete"
      ],
      "route": "agents/registration"
    },
    {
      "type": "http",
      "direction": "out",
      "name": "$return"
    }
  ]
}
@ -1,96 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import logging
from typing import Optional, Union

import azure.functions as func
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error
from onefuzztypes.requests import ContainerCreate, ContainerDelete, ContainerGet
from onefuzztypes.responses import BoolResult, ContainerInfo, ContainerInfoBase

from ..onefuzzlib.azure.containers import (
    create_container,
    delete_container,
    get_container_metadata,
    get_container_sas_url,
    get_containers,
)
from ..onefuzzlib.azure.storage import StorageType
from ..onefuzzlib.endpoint_authorization import call_if_user
from ..onefuzzlib.request import not_ok, ok, parse_request


def get(req: func.HttpRequest) -> func.HttpResponse:
    request: Optional[Union[ContainerGet, Error]] = None
    if req.get_body():
        request = parse_request(ContainerGet, req)

    if isinstance(request, Error):
        return not_ok(request, context="container get")
    if request is not None:
        metadata = get_container_metadata(request.name, StorageType.corpus)
        if metadata is None:
            return not_ok(
                Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid container"]),
                context=request.name,
            )

        info = ContainerInfo(
            name=request.name,
            sas_url=get_container_sas_url(
                request.name,
                StorageType.corpus,
                read=True,
                write=True,
                delete=True,
                list_=True,
            ),
            metadata=metadata,
        )
        return ok(info)

    containers = get_containers(StorageType.corpus)

    container_info = []
    for name in containers:
        container_info.append(ContainerInfoBase(name=name, metadata=containers[name]))

    return ok(container_info)


def post(req: func.HttpRequest) -> func.HttpResponse:
    request = parse_request(ContainerCreate, req)
    if isinstance(request, Error):
        return not_ok(request, context="container create")

    logging.info("container - creating %s", request.name)
    sas = create_container(request.name, StorageType.corpus, metadata=request.metadata)
    if sas:
        return ok(
            ContainerInfo(name=request.name, sas_url=sas, metadata=request.metadata)
        )
    return not_ok(
        Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid container"]),
        context=request.name,
    )


def delete(req: func.HttpRequest) -> func.HttpResponse:
    request = parse_request(ContainerDelete, req)
    if isinstance(request, Error):
        return not_ok(request, context="container delete")

    logging.info("container - deleting %s", request.name)
    return ok(BoolResult(result=delete_container(request.name, StorageType.corpus)))


def main(req: func.HttpRequest) -> func.HttpResponse:
    methods = {"GET": get, "POST": post, "DELETE": delete}
    method = methods[req.method]
    result = call_if_user(req, method)

    return result
@ -1,21 +0,0 @@
{
  "scriptFile": "__init__.py",
  "bindings": [
    {
      "authLevel": "anonymous",
      "type": "httpTrigger",
      "direction": "in",
      "name": "req",
      "methods": [
        "get",
        "post",
        "delete"
      ]
    },
    {
      "type": "http",
      "direction": "out",
      "name": "$return"
    }
  ]
}
@ -1,55 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from datetime import timedelta

import azure.functions as func
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error, FileEntry

from ..onefuzzlib.azure.containers import (
    blob_exists,
    container_exists,
    get_file_sas_url,
)
from ..onefuzzlib.azure.storage import StorageType
from ..onefuzzlib.endpoint_authorization import call_if_user
from ..onefuzzlib.request import not_ok, parse_uri, redirect


def get(req: func.HttpRequest) -> func.HttpResponse:
    request = parse_uri(FileEntry, req)
    if isinstance(request, Error):
        return not_ok(request, context="download")

    if not container_exists(request.container, StorageType.corpus):
        return not_ok(
            Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid container"]),
            context=request.container,
        )

    if not blob_exists(request.container, request.filename, StorageType.corpus):
        return not_ok(
            Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid filename"]),
            context=request.filename,
        )

    return redirect(
        get_file_sas_url(
            request.container,
            request.filename,
            StorageType.corpus,
            read=True,
            duration=timedelta(minutes=5),
        )
    )


def main(req: func.HttpRequest) -> func.HttpResponse:
    methods = {"GET": get}
    method = methods[req.method]
    result = call_if_user(req, method)

    return result
@ -1,19 +0,0 @@
{
  "scriptFile": "__init__.py",
  "bindings": [
    {
      "authLevel": "anonymous",
      "type": "httpTrigger",
      "direction": "in",
      "name": "req",
      "methods": [
        "get"
      ]
    },
    {
      "type": "http",
      "direction": "out",
      "name": "$return"
    }
  ]
}
@ -1,34 +0,0 @@
{
  "version": "2.0",
  "extensionBundle": {
    "id": "Microsoft.Azure.Functions.ExtensionBundle",
    "version": "[1.*, 2.0.0)"
  },
  "logging": {
    "logLevel": {
      "default": "Warning",
      "Host": "Warning",
      "Function": "Information",
      "Host.Aggregator": "Warning",
      "logging:logLevel:Worker": "Warning",
      "logging:logLevel:Microsoft": "Warning",
      "AzureFunctionsJobHost:logging:logLevel:Host.Function.Console": "Warning"
    },
    "applicationInsights": {
      "samplingSettings": {
        "isEnabled": true,
        "InitialSamplingPercentage": 100,
        "maxTelemetryItemsPerSecond": 20,
        "excludedTypes": "Exception"
      }
    }
  },
  "extensions": {
    "queues": {
      "maxPollingInterval": "00:00:02",
      "batchSize": 32,
      "maxDequeueCount": 5
    }
  },
  "functionTimeout": "00:15:00"
}
@ -1,43 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import azure.functions as func
from onefuzztypes.responses import Info

from ..onefuzzlib.azure.creds import (
    get_base_region,
    get_base_resource_group,
    get_insights_appid,
    get_insights_instrumentation_key,
    get_instance_id,
    get_subscription,
)
from ..onefuzzlib.endpoint_authorization import call_if_user
from ..onefuzzlib.request import ok
from ..onefuzzlib.versions import versions


def get(req: func.HttpRequest) -> func.HttpResponse:
    response = ok(
        Info(
            resource_group=get_base_resource_group(),
            region=get_base_region(),
            subscription=get_subscription(),
            versions=versions(),
            instance_id=get_instance_id(),
            insights_appid=get_insights_appid(),
            insights_instrumentation_key=get_insights_instrumentation_key(),
        )
    )

    return response


def main(req: func.HttpRequest) -> func.HttpResponse:
    methods = {"GET": get}
    method = methods[req.method]
    result = call_if_user(req, method)

    return result
@ -1,19 +0,0 @@
{
  "scriptFile": "__init__.py",
  "bindings": [
    {
      "authLevel": "anonymous",
      "type": "httpTrigger",
      "direction": "in",
      "name": "req",
      "methods": [
        "get"
      ]
    },
    {
      "type": "http",
      "direction": "out",
      "name": "$return"
    }
  ]
}
@ -1,77 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import logging

import azure.functions as func
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error
from onefuzztypes.requests import InstanceConfigUpdate

from ..onefuzzlib.azure.nsg import is_onefuzz_nsg, list_nsgs, set_allowed
from ..onefuzzlib.config import InstanceConfig
from ..onefuzzlib.endpoint_authorization import call_if_user, can_modify_config
from ..onefuzzlib.request import not_ok, ok, parse_request


def get(req: func.HttpRequest) -> func.HttpResponse:
    return ok(InstanceConfig.fetch())


def post(req: func.HttpRequest) -> func.HttpResponse:
    request = parse_request(InstanceConfigUpdate, req)
    if isinstance(request, Error):
        return not_ok(request, context="instance_config_update")

    config = InstanceConfig.fetch()

    if not can_modify_config(req, config):
        return not_ok(
            Error(code=ErrorCode.INVALID_PERMISSION, errors=["unauthorized"]),
            context="instance_config_update",
        )

    update_nsg = False
    if request.config.proxy_nsg_config and config.proxy_nsg_config:
        request_config = request.config.proxy_nsg_config
        current_config = config.proxy_nsg_config
        if set(request_config.allowed_service_tags) != set(
            current_config.allowed_service_tags
        ) or set(request_config.allowed_ips) != set(current_config.allowed_ips):
            update_nsg = True

    config.update(request.config)
    config.save()

    # Update All NSGs
    if update_nsg:
        nsgs = list_nsgs()
        for nsg in nsgs:
            logging.info(
                "Checking if nsg: %s (%s) owned by OneFuzz" % (nsg.location, nsg.name)
            )
            if is_onefuzz_nsg(nsg.location, nsg.name):
                result = set_allowed(nsg.location, request.config.proxy_nsg_config)
                if isinstance(result, Error):
                    return not_ok(
                        Error(
                            code=ErrorCode.UNABLE_TO_CREATE,
                            errors=[
                                "Unable to update nsg %s due to %s"
                                % (nsg.location, result)
                            ],
                        ),
                        context="instance_config_update",
                    )

    return ok(config)


def main(req: func.HttpRequest) -> func.HttpResponse:
    methods = {"GET": get, "POST": post}
    method = methods[req.method]
    result = call_if_user(req, method)

    return result
@ -1,20 +0,0 @@
{
  "scriptFile": "__init__.py",
  "bindings": [
    {
      "authLevel": "anonymous",
      "type": "httpTrigger",
      "direction": "in",
      "name": "req",
      "methods": [
        "get",
        "post"
      ]
    },
    {
      "type": "http",
      "direction": "out",
      "name": "$return"
    }
  ]
}
@ -1,42 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import azure.functions as func
from onefuzztypes.job_templates import JobTemplateRequest
from onefuzztypes.models import Error

from ..onefuzzlib.endpoint_authorization import call_if_user
from ..onefuzzlib.job_templates.templates import JobTemplateIndex
from ..onefuzzlib.request import not_ok, ok, parse_request
from ..onefuzzlib.user_credentials import parse_jwt_token


def get(req: func.HttpRequest) -> func.HttpResponse:
    configs = JobTemplateIndex.get_configs()
    return ok(configs)


def post(req: func.HttpRequest) -> func.HttpResponse:
    request = parse_request(JobTemplateRequest, req)
    if isinstance(request, Error):
        return not_ok(request, context="JobTemplateRequest")

    user_info = parse_jwt_token(req)
    if isinstance(user_info, Error):
        return not_ok(user_info, context="JobTemplateRequest")

    job = JobTemplateIndex.execute(request, user_info)
    if isinstance(job, Error):
        return not_ok(job, context="JobTemplateRequest")

    return ok(job)


def main(req: func.HttpRequest) -> func.HttpResponse:
    methods = {"GET": get, "POST": post}
    method = methods[req.method]
    result = call_if_user(req, method)

    return result
@ -1,20 +0,0 @@
{
  "scriptFile": "__init__.py",
  "bindings": [
    {
      "authLevel": "anonymous",
      "type": "httpTrigger",
      "direction": "in",
      "name": "req",
      "methods": [
        "get",
        "post"
      ]
    },
    {
      "type": "http",
      "direction": "out",
      "name": "$return"
    }
  ]
}
@ -1,69 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import azure.functions as func
from onefuzztypes.enums import ErrorCode
from onefuzztypes.job_templates import (
    JobTemplateDelete,
    JobTemplateGet,
    JobTemplateUpload,
)
from onefuzztypes.models import Error
from onefuzztypes.responses import BoolResult

from ..onefuzzlib.endpoint_authorization import call_if_user
from ..onefuzzlib.job_templates.templates import JobTemplateIndex
from ..onefuzzlib.request import not_ok, ok, parse_request


def get(req: func.HttpRequest) -> func.HttpResponse:
    request = parse_request(JobTemplateGet, req)
    if isinstance(request, Error):
        return not_ok(request, context="JobTemplateGet")

    if request.name:
        entry = JobTemplateIndex.get_base_entry(request.name)
        if entry is None:
            return not_ok(
                Error(code=ErrorCode.INVALID_REQUEST, errors=["no such job template"]),
                context="JobTemplateGet",
            )
        return ok(entry.template)

    templates = JobTemplateIndex.get_index()
    return ok(templates)


def post(req: func.HttpRequest) -> func.HttpResponse:
    request = parse_request(JobTemplateUpload, req)
    if isinstance(request, Error):
        return not_ok(request, context="JobTemplateUpload")

    entry = JobTemplateIndex(name=request.name, template=request.template)
    result = entry.save()
    if isinstance(result, Error):
        return not_ok(result, context="JobTemplateUpload")

    return ok(BoolResult(result=True))


def delete(req: func.HttpRequest) -> func.HttpResponse:
    request = parse_request(JobTemplateDelete, req)
    if isinstance(request, Error):
        return not_ok(request, context="JobTemplateDelete")

    entry = JobTemplateIndex.get(request.name)
    if entry is not None:
        entry.delete()

    return ok(BoolResult(result=entry is not None))


def main(req: func.HttpRequest) -> func.HttpResponse:
    methods = {"GET": get, "POST": post, "DELETE": delete}
    method = methods[req.method]
    result = call_if_user(req, method)

    return result
@ -1,22 +0,0 @@
{
  "scriptFile": "__init__.py",
  "bindings": [
    {
      "authLevel": "anonymous",
      "type": "httpTrigger",
      "direction": "in",
      "name": "req",
      "methods": [
        "get",
        "post",
        "delete"
      ],
      "route": "job_templates/manage"
    },
    {
      "type": "http",
      "direction": "out",
      "name": "$return"
    }
  ]
}
@ -1,109 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from uuid import uuid4

import azure.functions as func
from onefuzztypes.enums import ContainerType, ErrorCode, JobState
from onefuzztypes.models import Error, JobConfig, JobTaskInfo
from onefuzztypes.primitives import Container
from onefuzztypes.requests import JobGet, JobSearch

from ..onefuzzlib.azure.containers import create_container
from ..onefuzzlib.azure.storage import StorageType
from ..onefuzzlib.endpoint_authorization import call_if_user
from ..onefuzzlib.jobs import Job
from ..onefuzzlib.request import not_ok, ok, parse_request
from ..onefuzzlib.tasks.main import Task
from ..onefuzzlib.user_credentials import parse_jwt_token


def get(req: func.HttpRequest) -> func.HttpResponse:
    request = parse_request(JobSearch, req)
    if isinstance(request, Error):
        return not_ok(request, context="jobs")

    if request.job_id:
        job = Job.get(request.job_id)
        if not job:
            return not_ok(
                Error(code=ErrorCode.INVALID_JOB, errors=["no such job"]),
                context=request.job_id,
            )
        task_info = []
        for task in Task.search_states(job_id=request.job_id):
            task_info.append(
                JobTaskInfo(
                    task_id=task.task_id, type=task.config.task.type, state=task.state
                )
            )
        job.task_info = task_info
        return ok(job)

    jobs = Job.search_states(states=request.state)
    return ok(jobs)


def post(req: func.HttpRequest) -> func.HttpResponse:
    request = parse_request(JobConfig, req)
    if isinstance(request, Error):
        return not_ok(request, context="jobs create")

    user_info = parse_jwt_token(req)
    if isinstance(user_info, Error):
        return not_ok(user_info, context="jobs create")

    job = Job(job_id=uuid4(), config=request, user_info=user_info)
    # create the job logs container
    log_container_sas = create_container(
        Container(f"logs-{job.job_id}"),
        StorageType.corpus,
        metadata={"container_type": ContainerType.logs.name},
    )
    if not log_container_sas:
        return not_ok(
            Error(
                code=ErrorCode.UNABLE_TO_CREATE_CONTAINER,
                errors=["unable to create logs container"],
            ),
            context="logs",
        )
    sep_index = log_container_sas.find("?")
    if sep_index > 0:
        log_container = log_container_sas[:sep_index]
    else:
        log_container = log_container_sas

    job.config.logs = log_container
    job.save()

    return ok(job)


def delete(req: func.HttpRequest) -> func.HttpResponse:
    request = parse_request(JobGet, req)
    if isinstance(request, Error):
        return not_ok(request, context="jobs delete")

    job = Job.get(request.job_id)
    if not job:
        return not_ok(
            Error(code=ErrorCode.INVALID_JOB, errors=["no such job"]),
            context=request.job_id,
        )

    if job.state not in [JobState.stopped, JobState.stopping]:
        job.state = JobState.stopping
        job.save()

    return ok(job)


def main(req: func.HttpRequest) -> func.HttpResponse:
    methods = {"GET": get, "POST": post, "DELETE": delete}
    method = methods[req.method]
    result = call_if_user(req, method)

    return result
@ -1,21 +0,0 @@
{
  "scriptFile": "__init__.py",
  "bindings": [
    {
      "authLevel": "anonymous",
      "type": "httpTrigger",
      "direction": "in",
      "name": "req",
      "methods": [
        "get",
        "post",
        "delete"
      ]
    },
    {
      "type": "http",
      "direction": "out",
      "name": "$return"
    }
  ]
}
@ -1,34 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import azure.functions as func

from ..onefuzzlib.endpoint_authorization import call_if_user

# This endpoint handles the signalr negotation
# As we do not differentiate from clients at this time, we pass the Functions runtime
# provided connection straight to the client
#
# For more info:
# https://docs.microsoft.com/en-us/azure/azure-signalr/signalr-concept-internals


def main(req: func.HttpRequest, connectionInfoJson: str) -> func.HttpResponse:
    # NOTE: this is a sub-method because the call_if* do not support callbacks with
    # additional arguments at this time. Once call_if* supports additional arguments,
    # this should be made a generic function
    def post(req: func.HttpRequest) -> func.HttpResponse:
        return func.HttpResponse(
            connectionInfoJson,
            status_code=200,
            headers={"Content-type": "application/json"},
        )

    methods = {"POST": post}
    method = methods[req.method]

    result = call_if_user(req, method)

    return result
@ -1,25 +0,0 @@
{
  "scriptFile": "__init__.py",
  "bindings": [
    {
      "authLevel": "anonymous",
      "type": "httpTrigger",
      "direction": "in",
      "name": "req",
      "methods": [
        "post"
      ]
    },
    {
      "type": "http",
      "direction": "out",
      "name": "$return"
    },
    {
      "type": "signalRConnectionInfo",
      "direction": "in",
      "name": "connectionInfoJson",
      "hubName": "dashboard"
    }
  ]
}
@ -1,122 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import azure.functions as func
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error
from onefuzztypes.requests import NodeGet, NodeSearch, NodeUpdate
from onefuzztypes.responses import BoolResult

from ..onefuzzlib.endpoint_authorization import call_if_user, check_require_admins
from ..onefuzzlib.request import not_ok, ok, parse_request
from ..onefuzzlib.workers.nodes import Node, NodeMessage, NodeTasks


def get(req: func.HttpRequest) -> func.HttpResponse:
    request = parse_request(NodeSearch, req)
    if isinstance(request, Error):
        return not_ok(request, context="pool get")

    if request.machine_id:
        node = Node.get_by_machine_id(request.machine_id)
        if not node:
            return not_ok(
                Error(code=ErrorCode.UNABLE_TO_FIND, errors=["unable to find node"]),
                context=request.machine_id,
            )

        if isinstance(node, Error):
            return not_ok(node, context=request.machine_id)

        node.tasks = [n for n in NodeTasks.get_by_machine_id(request.machine_id)]
        node.messages = [
            x.message for x in NodeMessage.get_messages(request.machine_id)
        ]

        return ok(node)

    nodes = Node.search_states(
        states=request.state,
        pool_name=request.pool_name,
        scaleset_id=request.scaleset_id,
    )
    return ok(nodes)


def post(req: func.HttpRequest) -> func.HttpResponse:
    request = parse_request(NodeUpdate, req)
    if isinstance(request, Error):
        return not_ok(request, context="NodeUpdate")

    answer = check_require_admins(req)
    if isinstance(answer, Error):
        return not_ok(answer, context="NodeUpdate")

    node = Node.get_by_machine_id(request.machine_id)
    if not node:
        return not_ok(
            Error(code=ErrorCode.UNABLE_TO_FIND, errors=["unable to find node"]),
            context=request.machine_id,
        )
    if request.debug_keep_node is not None:
        node.debug_keep_node = request.debug_keep_node

    node.save()
    return ok(BoolResult(result=True))


def delete(req: func.HttpRequest) -> func.HttpResponse:
    request = parse_request(NodeGet, req)
    if isinstance(request, Error):
        return not_ok(request, context="NodeDelete")

    answer = check_require_admins(req)
    if isinstance(answer, Error):
        return not_ok(answer, context="NodeDelete")

    node = Node.get_by_machine_id(request.machine_id)
    if not node:
        return not_ok(
            Error(code=ErrorCode.UNABLE_TO_FIND, errors=["unable to find node"]),
            context=request.machine_id,
        )

    node.set_halt()
    if node.debug_keep_node:
        node.debug_keep_node = False
        node.save()

    return ok(BoolResult(result=True))


def patch(req: func.HttpRequest) -> func.HttpResponse:
    request = parse_request(NodeGet, req)
    if isinstance(request, Error):
        return not_ok(request, context="NodeReimage")

    answer = check_require_admins(req)
    if isinstance(answer, Error):
        return not_ok(answer, context="NodeReimage")

    node = Node.get_by_machine_id(request.machine_id)
    if not node:
        return not_ok(
            Error(code=ErrorCode.UNABLE_TO_FIND, errors=["unable to find node"]),
            context=request.machine_id,
        )

    node.stop(done=True)
    if node.debug_keep_node:
        node.debug_keep_node = False
        node.save()
    return ok(BoolResult(result=True))


def main(req: func.HttpRequest) -> func.HttpResponse:
    methods = {"GET": get, "PATCH": patch, "DELETE": delete, "POST": post}
    method = methods[req.method]
    result = call_if_user(req, method)

    return result
@ -1,22 +0,0 @@
{
  "scriptFile": "__init__.py",
  "bindings": [
    {
      "authLevel": "anonymous",
      "type": "httpTrigger",
      "direction": "in",
      "name": "req",
      "methods": [
        "get",
        "patch",
        "delete",
        "post"
      ]
    },
    {
      "type": "http",
      "direction": "out",
      "name": "$return"
    }
  ]
}
@ -1,40 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import azure.functions as func
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error
from onefuzztypes.requests import NodeAddSshKey
from onefuzztypes.responses import BoolResult

from ..onefuzzlib.endpoint_authorization import call_if_user
from ..onefuzzlib.request import not_ok, ok, parse_request
from ..onefuzzlib.workers.nodes import Node


def post(req: func.HttpRequest) -> func.HttpResponse:
    request = parse_request(NodeAddSshKey, req)
    if isinstance(request, Error):
        return not_ok(request, context="NodeAddSshKey")

    node = Node.get_by_machine_id(request.machine_id)
    if not node:
        return not_ok(
            Error(code=ErrorCode.UNABLE_TO_FIND, errors=["unable to find node"]),
            context=request.machine_id,
        )
    result = node.add_ssh_public_key(public_key=request.public_key)
    if isinstance(result, Error):
        return not_ok(result, context="NodeAddSshKey")

    return ok(BoolResult(result=True))


def main(req: func.HttpRequest) -> func.HttpResponse:
    methods = {"POST": post}
    method = methods[req.method]
    result = call_if_user(req, method)

    return result
@ -1,20 +0,0 @@
{
  "scriptFile": "__init__.py",
  "bindings": [
    {
      "authLevel": "anonymous",
      "type": "httpTrigger",
      "direction": "in",
      "name": "req",
      "methods": [
        "post"
      ],
      "route": "node/add_ssh_key"
    },
    {
      "type": "http",
      "direction": "out",
      "name": "$return"
    }
  ]
}
@ -1,70 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import logging

import azure.functions as func
from onefuzztypes.models import Error
from onefuzztypes.requests import (
    NotificationCreate,
    NotificationGet,
    NotificationSearch,
)

from ..onefuzzlib.endpoint_authorization import call_if_user
from ..onefuzzlib.notifications.main import Notification
from ..onefuzzlib.request import not_ok, ok, parse_request


def get(req: func.HttpRequest) -> func.HttpResponse:
    logging.info("notification search")
    request = parse_request(NotificationSearch, req)
    if isinstance(request, Error):
        return not_ok(request, context="notification search")

    if request.container:
        entries = Notification.search(query={"container": request.container})
    else:
        entries = Notification.search()

    return ok(entries)


def post(req: func.HttpRequest) -> func.HttpResponse:
    logging.info("adding notification hook")
    request = parse_request(NotificationCreate, req)
    if isinstance(request, Error):
        return not_ok(request, context="notification create")

    entry = Notification.create(
        container=request.container,
        config=request.config,
        replace_existing=request.replace_existing,
    )
    if isinstance(entry, Error):
        return not_ok(entry, context="notification create")

    return ok(entry)


def delete(req: func.HttpRequest) -> func.HttpResponse:
    request = parse_request(NotificationGet, req)
    if isinstance(request, Error):
        return not_ok(request, context="notification delete")

    entry = Notification.get_by_id(request.notification_id)
    if isinstance(entry, Error):
        return not_ok(entry, context="notification delete")

    entry.delete()
    return ok(entry)


def main(req: func.HttpRequest) -> func.HttpResponse:
    methods = {"GET": get, "POST": post, "DELETE": delete}
    method = methods[req.method]
    result = call_if_user(req, method)

    return result
@ -1,21 +0,0 @@
{
  "scriptFile": "__init__.py",
  "bindings": [
    {
      "authLevel": "anonymous",
      "type": "httpTrigger",
      "direction": "in",
      "name": "req",
      "methods": [
        "get",
        "post",
        "delete"
      ]
    },
    {
      "type": "http",
      "direction": "out",
      "name": "$return"
    }
  ]
}
@ -1,5 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

# pylint: disable=W0612,C0111
__version__ = "0.0.0"
@ -1,271 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import datetime
import logging
from typing import Optional, cast
from uuid import UUID

from onefuzztypes.enums import (
    ErrorCode,
    NodeState,
    NodeTaskState,
    TaskDebugFlag,
    TaskState,
)
from onefuzztypes.models import (
    Error,
    NodeDoneEventData,
    NodeSettingUpEventData,
    NodeStateUpdate,
    Result,
    WorkerDoneEvent,
    WorkerEvent,
    WorkerRunningEvent,
)

from .task_event import TaskEvent
from .tasks.main import Task
from .workers.nodes import Node, NodeTasks

MAX_OUTPUT_SIZE = 4096


def get_node(machine_id: UUID) -> Result[Node]:
    node = Node.get_by_machine_id(machine_id)
    if not node:
        return Error(code=ErrorCode.INVALID_NODE, errors=["unable to find node"])
    return node


def on_state_update(
    machine_id: UUID,
    state_update: NodeStateUpdate,
) -> Result[None]:
    state = state_update.state
    node = get_node(machine_id)
    if isinstance(node, Error):
        if state == NodeState.done:
            logging.warning(
                "unable to process state update event. machine_id:"
                f"{machine_id} state event:{state_update} error:{node}"
            )
            return None
        return node

    if state == NodeState.free:
        if node.reimage_requested or node.delete_requested:
            logging.info("stopping free node with reset flags: %s", node.machine_id)
            node.stop()
            return None

        if node.could_shrink_scaleset():
            logging.info("stopping free node to resize scaleset: %s", node.machine_id)
            node.set_halt()
            return None

    if state == NodeState.init:
        if node.delete_requested:
            logging.info("stopping node (init and delete_requested): %s", machine_id)
            node.stop()
            return None

        # not checking reimage_requested, as nodes only send 'init' state once. If
        # they send 'init' with reimage_requested, it's because the node was reimaged
        # successfully.
        node.reimage_requested = False
        node.initialized_at = datetime.datetime.now(datetime.timezone.utc)
        node.set_state(state)

        return None

    logging.info("node state update: %s from:%s to:%s", machine_id, node.state, state)
    node.set_state(state)

    if state == NodeState.free:
        logging.info("node now available for work: %s", machine_id)
    elif state == NodeState.setting_up:
        # Model-validated.
        #
        # This field will be required in the future.
        # For now, it is optional for back compat.
        setting_up_data = cast(
            Optional[NodeSettingUpEventData],
            state_update.data,
        )

        if setting_up_data:
            if not setting_up_data.tasks:
                return Error(
                    code=ErrorCode.INVALID_REQUEST,
                    errors=["setup without tasks. machine_id: %s", str(machine_id)],
                )

            for task_id in setting_up_data.tasks:
                task = Task.get_by_task_id(task_id)
                if isinstance(task, Error):
                    return task

                logging.info(
                    "node starting task. machine_id: %s job_id: %s task_id: %s",
                    machine_id,
                    task.job_id,
                    task.task_id,
                )

                # The task state may be `running` if it has `vm_count` > 1, and
                # another node is concurrently executing the task. If so, leave
                # the state as-is, to represent the max progress made.
                #
                # Other states we would want to preserve are excluded by the
                # outermost conditional check.
                if task.state not in [TaskState.running, TaskState.setting_up]:
                    task.set_state(TaskState.setting_up)

                # Note: we set the node task state to `setting_up`, even though
                # the task itself may be `running`.
                node_task = NodeTasks(
                    machine_id=machine_id,
                    task_id=task_id,
                    state=NodeTaskState.setting_up,
                )
                node_task.save()

    elif state == NodeState.done:
        # Model-validated.
        #
        # This field will be required in the future.
        # For now, it is optional for back compat.
        done_data = cast(Optional[NodeDoneEventData], state_update.data)
        error = None
        if done_data:
            if done_data.error:
                error_text = done_data.json(exclude_none=True)
                error = Error(
                    code=ErrorCode.TASK_FAILED,
                    errors=[error_text],
                )
                logging.error(
                    "node 'done' with error: machine_id:%s, data:%s",
                    machine_id,
                    error_text,
                )

        # if tasks are running on the node when it reports as Done
        # those are stopped early
        node.mark_tasks_stopped_early(error=error)
        node.to_reimage(done=True)

    return None


def on_worker_event_running(
    machine_id: UUID, event: WorkerRunningEvent
) -> Result[None]:
    task = Task.get_by_task_id(event.task_id)
    if isinstance(task, Error):
        return task

    node = get_node(machine_id)
    if isinstance(node, Error):
        return node

    if node.state not in NodeState.ready_for_reset():
        node.set_state(NodeState.busy)

    node_task = NodeTasks(
        machine_id=machine_id, task_id=event.task_id, state=NodeTaskState.running
    )
    node_task.save()

    if task.state in TaskState.shutting_down():
        logging.info(
            "ignoring task start from node. "
            "machine_id:%s job_id:%s task_id:%s (state: %s)",
            machine_id,
            task.job_id,
            task.task_id,
            task.state,
        )
        return None

    logging.info(
        "task started on node. machine_id:%s job_id%s task_id:%s",
        machine_id,
        task.job_id,
        task.task_id,
    )
    task.set_state(TaskState.running)

    task_event = TaskEvent(
        task_id=task.task_id,
        machine_id=machine_id,
        event_data=WorkerEvent(running=event),
    )
    task_event.save()

    return None


def on_worker_event_done(machine_id: UUID, event: WorkerDoneEvent) -> Result[None]:
    task = Task.get_by_task_id(event.task_id)
    if isinstance(task, Error):
        return task

    node = get_node(machine_id)
    if isinstance(node, Error):
        return node

    if event.exit_status.success:
        logging.info(
            "task done. %s:%s status:%s", task.job_id, task.task_id, event.exit_status
        )
        task.mark_stopping()
        if (
            task.config.debug
            and TaskDebugFlag.keep_node_on_completion in task.config.debug
        ):
            node.debug_keep_node = True
            node.save()
    else:
        task.mark_failed(
            Error(
                code=ErrorCode.TASK_FAILED,
                errors=[
                    "task failed. exit_status:%s" % event.exit_status,
                    event.stdout[-MAX_OUTPUT_SIZE:],
                    event.stderr[-MAX_OUTPUT_SIZE:],
                ],
            )
        )

        if task.config.debug and (
            TaskDebugFlag.keep_node_on_failure in task.config.debug
            or TaskDebugFlag.keep_node_on_completion in task.config.debug
        ):
            node.debug_keep_node = True
            node.save()

    if not node.debug_keep_node:
        node_task = NodeTasks.get(machine_id, event.task_id)
        if node_task:
            node_task.delete()

    event.stdout = event.stdout[-MAX_OUTPUT_SIZE:]
    event.stderr = event.stderr[-MAX_OUTPUT_SIZE:]
    task_event = TaskEvent(
        task_id=task.task_id, machine_id=machine_id, event_data=WorkerEvent(done=event)
    )
    task_event.save()
    return None


def on_worker_event(machine_id: UUID, event: WorkerEvent) -> Result[None]:
    if event.running:
        return on_worker_event_running(machine_id, event.running)
    elif event.done:
        return on_worker_event_done(machine_id, event.done)
    else:
        raise NotImplementedError
@ -1,171 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
from onefuzztypes.enums import NodeState, ScalesetState
|
||||
from onefuzztypes.models import AutoScaleConfig, TaskPool
|
||||
|
||||
from .tasks.main import Task
|
||||
from .workers.nodes import Node
|
||||
from .workers.pools import Pool
|
||||
from .workers.scalesets import Scaleset
|
||||
|
||||
|
||||
def scale_up(pool: Pool, scalesets: List[Scaleset], nodes_needed: int) -> None:
|
||||
logging.info("Scaling up")
|
||||
autoscale_config = pool.autoscale
|
||||
if not isinstance(autoscale_config, AutoScaleConfig):
|
||||
return
|
||||
|
||||
for scaleset in scalesets:
|
||||
if scaleset.state in [ScalesetState.running, ScalesetState.resize]:
|
||||
|
||||
max_size = min(scaleset.max_size(), autoscale_config.scaleset_size)
|
||||
logging.info(
|
||||
"scaleset:%s size:%d max_size:%d",
|
||||
scaleset.scaleset_id,
|
||||
scaleset.size,
|
||||
max_size,
|
||||
)
|
||||
if scaleset.size < max_size:
|
||||
current_size = scaleset.size
|
||||
if nodes_needed <= max_size - current_size:
|
||||
scaleset.set_size(current_size + nodes_needed)
|
||||
nodes_needed = 0
|
||||
else:
|
||||
scaleset.set_size(max_size)
|
||||
nodes_needed = nodes_needed - (max_size - current_size)
|
||||
|
||||
else:
|
||||
continue
|
||||
|
||||
if nodes_needed == 0:
|
||||
return
|
||||
|
||||
for _ in range(
|
||||
math.ceil(
|
||||
nodes_needed
|
||||
/ min(
|
||||
Scaleset.scaleset_max_size(autoscale_config.image),
|
||||
autoscale_config.scaleset_size,
|
||||
)
|
||||
)
|
||||
):
|
||||
logging.info("Creating Scaleset for Pool %s", pool.name)
|
||||
max_nodes_scaleset = min(
|
||||
Scaleset.scaleset_max_size(autoscale_config.image),
|
||||
autoscale_config.scaleset_size,
|
||||
nodes_needed,
|
||||
)
|
||||
|
||||
if not autoscale_config.region:
|
||||
raise Exception("Region is missing")
|
||||
|
||||
Scaleset.create(
|
||||
pool_name=pool.name,
|
||||
vm_sku=autoscale_config.vm_sku,
|
||||
image=autoscale_config.image,
|
||||
region=autoscale_config.region,
|
||||
size=max_nodes_scaleset,
|
||||
spot_instances=autoscale_config.spot_instances,
|
||||
ephemeral_os_disks=autoscale_config.ephemeral_os_disks,
|
||||
tags={"pool": pool.name},
|
||||
)
|
||||
nodes_needed -= max_nodes_scaleset
|
||||
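The second half of scale_up splits whatever demand remains across newly created scalesets, each capped by the smaller of the image's scaleset limit and the configured scaleset_size. A pure-Python sketch of just that arithmetic (the numbers are invented for illustration):

import math
from typing import List


def plan_new_scalesets(
    nodes_needed: int, image_max: int, config_max: int
) -> List[int]:
    # Each new scaleset receives at most min(image_max, config_max) nodes
    # until the remaining demand is covered.
    cap = min(image_max, config_max)
    sizes = []
    for _ in range(math.ceil(nodes_needed / cap)):
        size = min(cap, nodes_needed)
        sizes.append(size)
        nodes_needed -= size
    return sizes


# e.g. 250 nodes with a 100-node cap -> three scalesets of 100, 100 and 50
assert plan_new_scalesets(250, 1000, 100) == [100, 100, 50]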
|
||||
|
||||
def scale_down(scalesets: List[Scaleset], nodes_to_remove: int) -> None:
|
||||
logging.info("Scaling down")
|
||||
for scaleset in scalesets:
|
||||
num_of_nodes = len(Node.search_states(scaleset_id=scaleset.scaleset_id))
|
||||
if scaleset.size != num_of_nodes and scaleset.state not in [
|
||||
ScalesetState.resize,
|
||||
ScalesetState.shutdown,
|
||||
ScalesetState.halt,
|
||||
]:
|
||||
scaleset.set_state(ScalesetState.resize)
|
||||
|
||||
free_nodes = Node.search_states(
|
||||
scaleset_id=scaleset.scaleset_id,
|
||||
states=[NodeState.free],
|
||||
)
|
||||
nodes = []
|
||||
for node in free_nodes:
|
||||
if not node.delete_requested:
|
||||
nodes.append(node)
|
||||
logging.info("Scaleset: %s, #Free Nodes: %s", scaleset.scaleset_id, len(nodes))
|
||||
|
||||
if nodes and nodes_to_remove > 0:
|
||||
max_nodes_remove = min(len(nodes), nodes_to_remove)
|
||||
# All nodes in the scaleset are free, so the whole VMSS can be shut down
|
||||
if max_nodes_remove >= scaleset.size and len(nodes) >= scaleset.size:
|
||||
scaleset.set_state(ScalesetState.shutdown)
|
||||
nodes_to_remove = nodes_to_remove - scaleset.size
|
||||
for node in nodes:
|
||||
node.set_shutdown()
|
||||
continue
|
||||
|
||||
# Resize of VMSS needed
|
||||
scaleset.set_size(scaleset.size - max_nodes_remove)
|
||||
nodes_to_remove = nodes_to_remove - max_nodes_remove
|
||||
scaleset.set_state(ScalesetState.resize)
|
||||
|
||||
|
||||
def get_vm_count(tasks: List[Task]) -> int:
|
||||
count = 0
|
||||
for task in tasks:
|
||||
task_pool = task.get_pool()
|
||||
if (
|
||||
not task_pool
|
||||
or not isinstance(task_pool, Pool)
|
||||
or not isinstance(task.config.pool, TaskPool)
|
||||
):
|
||||
continue
|
||||
count += task.config.pool.count
|
||||
return count
|
||||
|
||||
|
||||
def autoscale_pool(pool: Pool) -> None:
|
||||
logging.info("autoscale: %s", pool.autoscale)
|
||||
if not pool.autoscale:
|
||||
return
|
||||
|
||||
# get all the tasks for the pool (only tasks that are not stopped are counted)
|
||||
tasks = Task.get_tasks_by_pool_name(pool.name)
|
||||
logging.info("Pool: %s, #Tasks %d", pool.name, len(tasks))
|
||||
|
||||
num_of_tasks = get_vm_count(tasks)
|
||||
nodes_needed = max(num_of_tasks, pool.autoscale.min_size)
|
||||
if pool.autoscale.max_size:
|
||||
nodes_needed = min(nodes_needed, pool.autoscale.max_size)
|
||||
|
||||
# reconcile the pool's scalesets with the number of nodes needed
|
||||
# get all the scalesets for the pool
|
||||
scalesets = Scaleset.search_by_pool(pool.name)
|
||||
pool_resize = False
|
||||
for scaleset in scalesets:
|
||||
if scaleset.state in ScalesetState.modifying():
|
||||
pool_resize = True
|
||||
break
|
||||
nodes_needed = nodes_needed - scaleset.size
|
||||
|
||||
if pool_resize:
|
||||
return
|
||||
|
||||
logging.info("Pool: %s, #Nodes Needed: %d", pool.name, nodes_needed)
|
||||
if nodes_needed > 0:
|
||||
# resizing scaleset or creating new scaleset.
|
||||
scale_up(pool, scalesets, nodes_needed)
|
||||
elif nodes_needed < 0:
|
||||
for scaleset in scalesets:
|
||||
nodes = Node.search_states(scaleset_id=scaleset.scaleset_id)
|
||||
for node in nodes:
|
||||
if node.delete_requested:
|
||||
nodes_needed += 1
|
||||
if nodes_needed < 0:
|
||||
scale_down(scalesets, abs(nodes_needed))
|
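The node-count decision in autoscale_pool reduces to a small calculation: demand is the tasks' requested VM count, floored at the pool's minimum and optionally capped at its maximum, minus the capacity already provided by existing scalesets. A self-contained sketch with invented numbers:

from typing import List, Optional


def nodes_delta(
    task_vm_count: int,
    min_size: int,
    max_size: Optional[int],
    scaleset_sizes: List[int],
) -> int:
    needed = max(task_vm_count, min_size)
    if max_size:
        needed = min(needed, max_size)
    # Positive result -> scale up by that many nodes; negative -> scale down.
    return needed - sum(scaleset_sizes)


assert nodes_delta(30, 5, 50, [10, 10]) == 10  # grow by 10 nodes
assert nodes_delta(2, 5, 50, [10]) == -5       # shrink by 5 nodes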
@ -1,37 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import subprocess # nosec - used for ssh key generation
|
||||
import tempfile
|
||||
from typing import Tuple
|
||||
from uuid import uuid4
|
||||
|
||||
from onefuzztypes.models import Authentication
|
||||
|
||||
|
||||
def generate_keypair() -> Tuple[str, str]:
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
filename = os.path.join(tmpdir, "key")
|
||||
|
||||
cmd = ["ssh-keygen", "-t", "rsa", "-f", filename, "-P", "", "-b", "2048"]
|
||||
subprocess.check_output(cmd) # nosec - all arguments are under our control
|
||||
|
||||
with open(filename, "r") as handle:
|
||||
private = handle.read()
|
||||
|
||||
with open(filename + ".pub", "r") as handle:
|
||||
public = handle.read().strip()
|
||||
|
||||
return (public, private)
|
||||
|
||||
|
||||
def build_auth() -> Authentication:
|
||||
public_key, private_key = generate_keypair()
|
||||
auth = Authentication(
|
||||
password=str(uuid4()), public_key=public_key, private_key=private_key
|
||||
)
|
||||
return auth
|
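A short usage sketch for the helpers above. Like the implementation, it assumes ssh-keygen is available on PATH; the exact private-key header depends on the local OpenSSH version.

public, private = generate_keypair()
assert public.startswith("ssh-rsa ")
assert "PRIVATE KEY" in private

# build_auth wraps a fresh keypair together with a random password
auth = build_auth()
assert auth.public_key.startswith("ssh-rsa ")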
@ -1,328 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import timedelta
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from azure.core.exceptions import ResourceNotFoundError
|
||||
from azure.mgmt.monitor.models import (
|
||||
AutoscaleProfile,
|
||||
AutoscaleSettingResource,
|
||||
ComparisonOperationType,
|
||||
DiagnosticSettingsResource,
|
||||
LogSettings,
|
||||
MetricStatisticType,
|
||||
MetricTrigger,
|
||||
RetentionPolicy,
|
||||
ScaleAction,
|
||||
ScaleCapacity,
|
||||
ScaleDirection,
|
||||
ScaleRule,
|
||||
ScaleType,
|
||||
TimeAggregationType,
|
||||
)
|
||||
from msrestazure.azure_exceptions import CloudError
|
||||
from onefuzztypes.enums import ErrorCode
|
||||
from onefuzztypes.models import Error
|
||||
from onefuzztypes.primitives import Region
|
||||
|
||||
from .creds import (
|
||||
get_base_region,
|
||||
get_base_resource_group,
|
||||
get_subscription,
|
||||
retry_on_auth_failure,
|
||||
)
|
||||
from .log_analytics import get_workspace_id
|
||||
from .monitor import get_monitor_client
|
||||
|
||||
|
||||
@retry_on_auth_failure()
|
||||
def get_auto_scale_settings(
|
||||
vmss: UUID,
|
||||
) -> Union[Optional[AutoscaleSettingResource], Error]:
|
||||
logging.info("Getting auto scale settings for %s" % vmss)
|
||||
client = get_monitor_client()
|
||||
resource_group = get_base_resource_group()
|
||||
|
||||
try:
|
||||
auto_scale_collections = client.autoscale_settings.list_by_resource_group(
|
||||
resource_group
|
||||
)
|
||||
for auto_scale in auto_scale_collections:
|
||||
if str(auto_scale.target_resource_uri).endswith(str(vmss)):
|
||||
logging.info("Found auto scale settings for %s" % vmss)
|
||||
return auto_scale
|
||||
|
||||
except (ResourceNotFoundError, CloudError):
|
||||
return Error(
|
||||
code=ErrorCode.INVALID_CONFIGURATION,
|
||||
errors=[
|
||||
"Failed to check if scaleset %s already has an autoscale resource"
|
||||
% vmss
|
||||
],
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@retry_on_auth_failure()
|
||||
def add_auto_scale_to_vmss(
|
||||
vmss: UUID, auto_scale_profile: AutoscaleProfile
|
||||
) -> Optional[Error]:
|
||||
logging.info("Checking scaleset %s for existing auto scale resources" % vmss)
|
||||
|
||||
existing_auto_scale_resource = get_auto_scale_settings(vmss)
|
||||
|
||||
if isinstance(existing_auto_scale_resource, Error):
|
||||
return existing_auto_scale_resource
|
||||
|
||||
if existing_auto_scale_resource is not None:
|
||||
logging.warning("Scaleset %s already has auto scale resource" % vmss)
|
||||
return None
|
||||
|
||||
auto_scale_resource = create_auto_scale_resource_for(
|
||||
vmss, get_base_region(), auto_scale_profile
|
||||
)
|
||||
if isinstance(auto_scale_resource, Error):
|
||||
return auto_scale_resource
|
||||
|
||||
diagnostics_resource = setup_auto_scale_diagnostics(
|
||||
auto_scale_resource.id, auto_scale_resource.name, get_workspace_id()
|
||||
)
|
||||
if isinstance(diagnostics_resource, Error):
|
||||
return diagnostics_resource
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def update_auto_scale(auto_scale_resource: AutoscaleSettingResource) -> Optional[Error]:
|
||||
logging.info("Updating auto scale resource: %s" % auto_scale_resource.name)
|
||||
client = get_monitor_client()
|
||||
resource_group = get_base_resource_group()
|
||||
|
||||
try:
|
||||
auto_scale_resource = client.autoscale_settings.create_or_update(
|
||||
resource_group, auto_scale_resource.name, auto_scale_resource
|
||||
)
|
||||
logging.info(
|
||||
"Successfully updated auto scale resource: %s" % auto_scale_resource.name
|
||||
)
|
||||
except (ResourceNotFoundError, CloudError):
|
||||
return Error(
|
||||
code=ErrorCode.UNABLE_TO_UPDATE,
|
||||
errors=[
|
||||
"unable to update auto scale resource with name: %s and profile: %s"
|
||||
% (auto_scale_resource.name, auto_scale_resource)
|
||||
],
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def create_auto_scale_resource_for(
|
||||
resource_id: UUID, location: Region, profile: AutoscaleProfile
|
||||
) -> Union[AutoscaleSettingResource, Error]:
|
||||
logging.info("Creating auto scale resource for: %s" % resource_id)
|
||||
client = get_monitor_client()
|
||||
resource_group = get_base_resource_group()
|
||||
subscription = get_subscription()
|
||||
|
||||
scaleset_uri = (
|
||||
"/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/virtualMachineScaleSets/%s" # noqa: E501
|
||||
% (subscription, resource_group, resource_id)
|
||||
)
|
||||
|
||||
params: Dict[str, Any] = {
|
||||
"location": location,
|
||||
"profiles": [profile],
|
||||
"target_resource_uri": scaleset_uri,
|
||||
"enabled": True,
|
||||
}
|
||||
|
||||
try:
|
||||
auto_scale_resource = client.autoscale_settings.create_or_update(
|
||||
resource_group, str(uuid.uuid4()), params
|
||||
)
|
||||
logging.info(
|
||||
"Successfully created auto scale resource %s for %s"
|
||||
% (auto_scale_resource.id, resource_id)
|
||||
)
|
||||
return auto_scale_resource
|
||||
except (ResourceNotFoundError, CloudError):
|
||||
return Error(
|
||||
code=ErrorCode.UNABLE_TO_CREATE,
|
||||
errors=[
|
||||
"unable to create auto scale resource for resource: %s with profile: %s"
|
||||
% (resource_id, profile)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def create_auto_scale_profile(
|
||||
queue_uri: str,
|
||||
min: int,
|
||||
max: int,
|
||||
default: int,
|
||||
scale_out_amount: int,
|
||||
scale_out_cooldown: int,
|
||||
scale_in_amount: int,
|
||||
scale_in_cooldown: int,
|
||||
) -> AutoscaleProfile:
|
||||
return AutoscaleProfile(
|
||||
name=str(uuid.uuid4()),
|
||||
capacity=ScaleCapacity(minimum=min, maximum=max, default=default),
|
||||
# Auto scale tuning guidance:
|
||||
# https://docs.microsoft.com/en-us/azure/architecture/best-practices/auto-scaling
|
||||
rules=[
|
||||
ScaleRule(
|
||||
metric_trigger=MetricTrigger(
|
||||
metric_name="ApproximateMessageCount",
|
||||
metric_resource_uri=queue_uri,
|
||||
# Check every 15 minutes
|
||||
time_grain=timedelta(minutes=15),
|
||||
# The average number of messages in the pool queue
|
||||
time_aggregation=TimeAggregationType.AVERAGE,
|
||||
statistic=MetricStatisticType.COUNT,
|
||||
# Over the past 15 minutes
|
||||
time_window=timedelta(minutes=15),
|
||||
# When there is at least 1 message in the pool queue
|
||||
operator=ComparisonOperationType.GREATER_THAN_OR_EQUAL,
|
||||
threshold=1,
|
||||
divide_per_instance=False,
|
||||
),
|
||||
scale_action=ScaleAction(
|
||||
direction=ScaleDirection.INCREASE,
|
||||
type=ScaleType.CHANGE_COUNT,
|
||||
value=scale_out_amount,
|
||||
cooldown=timedelta(minutes=scale_out_cooldown),
|
||||
),
|
||||
),
|
||||
# Scale in
|
||||
ScaleRule(
|
||||
# Scale in if no work in the past 20 mins
|
||||
metric_trigger=MetricTrigger(
|
||||
metric_name="ApproximateMessageCount",
|
||||
metric_resource_uri=queue_uri,
|
||||
# Check every 10 minutes
|
||||
time_grain=timedelta(minutes=10),
|
||||
# The average number of messages in the pool queue
|
||||
time_aggregation=TimeAggregationType.AVERAGE,
|
||||
statistic=MetricStatisticType.SUM,
|
||||
# Over the past 10 minutes
|
||||
time_window=timedelta(minutes=10),
|
||||
# When there are no messages in the pool queue
|
||||
operator=ComparisonOperationType.EQUALS,
|
||||
threshold=0,
|
||||
divide_per_instance=False,
|
||||
),
|
||||
scale_action=ScaleAction(
|
||||
direction=ScaleDirection.DECREASE,
|
||||
type=ScaleType.CHANGE_COUNT,
|
||||
value=scale_in_amount,
|
||||
cooldown=timedelta(minutes=scale_in_cooldown),
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def get_auto_scale_profile(scaleset_id: UUID) -> Union[AutoscaleProfile, Error, None]:
|
||||
logging.info("Getting scaleset %s for existing auto scale resources" % scaleset_id)
|
||||
client = get_monitor_client()
|
||||
resource_group = get_base_resource_group()
|
||||
|
||||
auto_scale_resource = None
|
||||
|
||||
try:
|
||||
auto_scale_collections = client.autoscale_settings.list_by_resource_group(
|
||||
resource_group
|
||||
)
|
||||
for auto_scale in auto_scale_collections:
|
||||
if str(auto_scale.target_resource_uri).endswith(str(scaleset_id)):
|
||||
auto_scale_resource = auto_scale
|
||||
auto_scale_profiles = auto_scale_resource.profiles
|
||||
if len(auto_scale_profiles) != 1:
|
||||
logging.info(
|
||||
"Found more than one autoscaling profile for scaleset %s"
|
||||
% scaleset_id
|
||||
)
|
||||
return auto_scale_profiles[0]
|
||||
|
||||
except (ResourceNotFoundError, CloudError):
|
||||
return Error(
|
||||
code=ErrorCode.INVALID_CONFIGURATION,
|
||||
errors=["Failed to query scaleset %s autoscale resource" % scaleset_id],
|
||||
)
|
||||
|
||||
|
||||
def default_auto_scale_profile(queue_uri: str, scaleset_size: int) -> AutoscaleProfile:
|
||||
return create_auto_scale_profile(
|
||||
queue_uri, 1, scaleset_size, scaleset_size, 1, 10, 1, 5
|
||||
)
|
||||
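The positional call in default_auto_scale_profile is compact but easy to misread; written with keyword arguments it is equivalent to the following sketch (queue_uri and scaleset_size are example values standing in for the caller's):

queue_uri = "https://example.queue.core.windows.net/pool-queue"  # example only
scaleset_size = 100                                               # example only

profile = create_auto_scale_profile(
    queue_uri=queue_uri,
    min=1,                  # never shrink below one node
    max=scaleset_size,      # never exceed the scaleset's configured size
    default=scaleset_size,
    scale_out_amount=1,     # add one node per scale-out action
    scale_out_cooldown=10,  # minutes to wait between scale-out actions
    scale_in_amount=1,      # remove one node per scale-in action
    scale_in_cooldown=5,    # minutes to wait between scale-in actions
)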
|
||||
|
||||
def shutdown_scaleset_rule(queue_uri: str) -> ScaleRule:
|
||||
return ScaleRule(
|
||||
# Scale in if there are 0 or more messages in the queue (aka: every time)
|
||||
metric_trigger=MetricTrigger(
|
||||
metric_name="ApproximateMessageCount",
|
||||
metric_resource_uri=queue_uri,
|
||||
# Check every 5 minutes
|
||||
time_grain=timedelta(minutes=5),
|
||||
# The average number of messages in the pool queue
|
||||
time_aggregation=TimeAggregationType.AVERAGE,
|
||||
statistic=MetricStatisticType.SUM,
|
||||
# Over the past 5 minutes
|
||||
time_window=timedelta(minutes=5),
|
||||
operator=ComparisonOperationType.GREATER_THAN_OR_EQUAL,
|
||||
threshold=0,
|
||||
divide_per_instance=False,
|
||||
),
|
||||
scale_action=ScaleAction(
|
||||
direction=ScaleDirection.DECREASE,
|
||||
type=ScaleType.CHANGE_COUNT,
|
||||
value=1,
|
||||
cooldown=timedelta(minutes=5),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def setup_auto_scale_diagnostics(
|
||||
auto_scale_resource_uri: str,
|
||||
auto_scale_resource_name: str,
|
||||
log_analytics_workspace_id: str,
|
||||
) -> Union[DiagnosticSettingsResource, Error]:
|
||||
logging.info("Setting up diagnostics for auto scale")
|
||||
client = get_monitor_client()
|
||||
|
||||
log_settings = LogSettings(
|
||||
enabled=True,
|
||||
category_group="allLogs",
|
||||
retention_policy=RetentionPolicy(enabled=True, days=30),
|
||||
)
|
||||
|
||||
params: Dict[str, Any] = {
|
||||
"logs": [log_settings],
|
||||
"workspace_id": log_analytics_workspace_id,
|
||||
}
|
||||
|
||||
try:
|
||||
diagnostics = client.diagnostic_settings.create_or_update(
|
||||
auto_scale_resource_uri, "%s-diagnostics" % auto_scale_resource_name, params
|
||||
)
|
||||
logging.info(
|
||||
"Diagnostics created for auto scale resource: %s" % auto_scale_resource_uri
|
||||
)
|
||||
return diagnostics
|
||||
except (ResourceNotFoundError, CloudError):
|
||||
return Error(
|
||||
code=ErrorCode.UNABLE_TO_CREATE,
|
||||
errors=[
|
||||
"unable to setup diagnostics for auto scale resource: %s"
|
||||
% (auto_scale_resource_uri)
|
||||
],
|
||||
)
|
@ -1,9 +0,0 @@
|
||||
from azure.mgmt.compute import ComputeManagementClient
|
||||
from memoization import cached
|
||||
|
||||
from .creds import get_identity, get_subscription
|
||||
|
||||
|
||||
@cached
|
||||
def get_compute_client() -> ComputeManagementClient:
|
||||
return ComputeManagementClient(get_identity(), get_subscription())
|
@ -1,388 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import urllib.parse
|
||||
from typing import Dict, Optional, Tuple, Union, cast
|
||||
|
||||
from azure.common import AzureHttpError, AzureMissingResourceHttpError
|
||||
from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
|
||||
from azure.storage.blob import (
|
||||
BlobClient,
|
||||
BlobSasPermissions,
|
||||
BlobServiceClient,
|
||||
ContainerClient,
|
||||
ContainerSasPermissions,
|
||||
generate_blob_sas,
|
||||
generate_container_sas,
|
||||
)
|
||||
from memoization import cached
|
||||
from onefuzztypes.primitives import Container
|
||||
|
||||
from .storage import (
|
||||
StorageType,
|
||||
choose_account,
|
||||
get_accounts,
|
||||
get_storage_account_name_key,
|
||||
get_storage_account_name_key_by_name,
|
||||
)
|
||||
|
||||
CONTAINER_SAS_DEFAULT_DURATION = datetime.timedelta(days=30)
|
||||
|
||||
|
||||
def get_url(account_name: str) -> str:
|
||||
return f"https://{account_name}.blob.core.windows.net/"
|
||||
|
||||
|
||||
@cached
|
||||
def get_blob_service(account_id: str) -> BlobServiceClient:
|
||||
logging.debug("getting blob container (account_id: %s)", account_id)
|
||||
account_name, account_key = get_storage_account_name_key(account_id)
|
||||
account_url = get_url(account_name)
|
||||
service = BlobServiceClient(account_url=account_url, credential=account_key)
|
||||
return service
|
||||
|
||||
|
||||
def container_metadata(
|
||||
container: Container, account_id: str
|
||||
) -> Optional[Dict[str, str]]:
|
||||
try:
|
||||
result = (
|
||||
get_blob_service(account_id)
|
||||
.get_container_client(container)
|
||||
.get_container_properties()
|
||||
)
|
||||
return cast(Dict[str, str], result)
|
||||
except AzureHttpError:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def find_container(
|
||||
container: Container, storage_type: StorageType
|
||||
) -> Optional[ContainerClient]:
|
||||
accounts = get_accounts(storage_type)
|
||||
|
||||
# check secondary accounts first by searching in reverse.
|
||||
#
|
||||
# By implementation, the primary account is specified first, followed by
|
||||
# any secondary accounts.
|
||||
#
|
||||
# Secondary accounts, if they exist, are preferred for containers and have
|
||||
# increased IOP rates, so this should be a slight optimization
|
||||
for account in reversed(accounts):
|
||||
client = get_blob_service(account).get_container_client(container)
|
||||
if client.exists():
|
||||
return client
|
||||
return None
|
||||
|
||||
|
||||
def container_exists(container: Container, storage_type: StorageType) -> bool:
|
||||
return find_container(container, storage_type) is not None
|
||||
|
||||
|
||||
def get_containers(storage_type: StorageType) -> Dict[str, Dict[str, str]]:
|
||||
containers: Dict[str, Dict[str, str]] = {}
|
||||
|
||||
for account_id in get_accounts(storage_type):
|
||||
containers.update(
|
||||
{
|
||||
x.name: x.metadata
|
||||
for x in get_blob_service(account_id).list_containers(
|
||||
include_metadata=True
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
return containers
|
||||
|
||||
|
||||
def get_container_metadata(
|
||||
container: Container, storage_type: StorageType
|
||||
) -> Optional[Dict[str, str]]:
|
||||
client = find_container(container, storage_type)
|
||||
if client is None:
|
||||
return None
|
||||
result = client.get_container_properties().metadata
|
||||
return cast(Dict[str, str], result)
|
||||
|
||||
|
||||
def add_container_sas_url(
|
||||
container_url: str, duration: datetime.timedelta = CONTAINER_SAS_DEFAULT_DURATION
|
||||
) -> str:
|
||||
parsed = urllib.parse.urlparse(container_url)
|
||||
query = urllib.parse.parse_qs(parsed.query)
|
||||
if "sig" in query:
|
||||
return container_url
|
||||
else:
|
||||
start, expiry = sas_time_window(duration)
|
||||
account_name = parsed.netloc.split(".")[0]
|
||||
account_key = get_storage_account_name_key_by_name(account_name)
|
||||
sas_token = generate_container_sas(
|
||||
account_name=account_name,
|
||||
container_name=parsed.path.split("/")[1],
|
||||
account_key=account_key,
|
||||
permission=ContainerSasPermissions(
|
||||
read=True, write=True, delete=True, list=True
|
||||
),
|
||||
expiry=expiry,
|
||||
start=start,
|
||||
)
|
||||
return f"{container_url}?{sas_token}"
|
||||
|
||||
|
||||
def get_or_create_container_client(
|
||||
container: Container,
|
||||
storage_type: StorageType,
|
||||
metadata: Optional[Dict[str, str]],
|
||||
) -> Optional[ContainerClient]:
|
||||
client = find_container(container, storage_type)
|
||||
if client is None:
|
||||
account = choose_account(storage_type)
|
||||
client = get_blob_service(account).get_container_client(container)
|
||||
try:
|
||||
client.create_container(metadata=metadata)
|
||||
except (ResourceExistsError, AzureHttpError) as err:
|
||||
# note: resource exists error happens during creation if the container
|
||||
# is being deleted
|
||||
logging.error(
|
||||
(
|
||||
"unable to create container. account: %s "
|
||||
"container: %s metadata: %s - %s"
|
||||
),
|
||||
account,
|
||||
container,
|
||||
metadata,
|
||||
err,
|
||||
)
|
||||
return None
|
||||
return client
|
||||
|
||||
|
||||
def create_container(
|
||||
container: Container,
|
||||
storage_type: StorageType,
|
||||
metadata: Optional[Dict[str, str]],
|
||||
) -> Optional[str]:
|
||||
client = get_or_create_container_client(container, storage_type, metadata)
|
||||
if client is None:
|
||||
return None
|
||||
return get_container_sas_url_service(
|
||||
client,
|
||||
read=True,
|
||||
write=True,
|
||||
delete=True,
|
||||
list_=True,
|
||||
)
|
||||
|
||||
|
||||
def delete_container(container: Container, storage_type: StorageType) -> bool:
|
||||
accounts = get_accounts(storage_type)
|
||||
deleted = False
|
||||
for account in accounts:
|
||||
service = get_blob_service(account)
|
||||
try:
|
||||
service.delete_container(container)
|
||||
deleted = True
|
||||
except ResourceNotFoundError:
|
||||
pass
|
||||
|
||||
return deleted
|
||||
|
||||
|
||||
def sas_time_window(
|
||||
duration: datetime.timedelta,
|
||||
) -> Tuple[datetime.datetime, datetime.datetime]:
|
||||
# SAS URLs are valid 6 hours earlier, primarily to work around dev
|
||||
# workstations having out-of-sync time. Additionally, SAS URLs expire
|
||||
# 15 minutes later than requested based on "Be careful with SAS start time"
|
||||
# guidance.
|
||||
# Ref: https://docs.microsoft.com/en-us/azure/storage/common/storage-sas-overview
|
||||
SAS_START_TIME_DELTA = datetime.timedelta(hours=6)
|
||||
SAS_END_TIME_DELTA = datetime.timedelta(minutes=15)
|
||||
|
||||
now = datetime.datetime.utcnow()
|
||||
start = now - SAS_START_TIME_DELTA
|
||||
expiry = now + duration + SAS_END_TIME_DELTA
|
||||
return (start, expiry)
|
||||
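For the default 30-day duration, sas_time_window produces a window that opens six hours in the past and closes fifteen minutes after the requested expiry; a self-contained check of that arithmetic:

import datetime

duration = datetime.timedelta(days=30)
now = datetime.datetime.utcnow()

start = now - datetime.timedelta(hours=6)                 # valid 6h in the past
expiry = now + duration + datetime.timedelta(minutes=15)  # requested + 15min grace

assert expiry - start == datetime.timedelta(days=30, hours=6, minutes=15)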
|
||||
|
||||
def get_container_sas_url_service(
|
||||
client: ContainerClient,
|
||||
*,
|
||||
read: bool = False,
|
||||
write: bool = False,
|
||||
delete: bool = False,
|
||||
list_: bool = False,
|
||||
delete_previous_version: bool = False,
|
||||
tag: bool = False,
|
||||
duration: datetime.timedelta = CONTAINER_SAS_DEFAULT_DURATION,
|
||||
) -> str:
|
||||
account_name = client.account_name
|
||||
container_name = client.container_name
|
||||
account_key = get_storage_account_name_key_by_name(account_name)
|
||||
|
||||
start, expiry = sas_time_window(duration)
|
||||
|
||||
sas = generate_container_sas(
|
||||
account_name,
|
||||
container_name,
|
||||
account_key=account_key,
|
||||
permission=ContainerSasPermissions(
|
||||
read=read,
|
||||
write=write,
|
||||
delete=delete,
|
||||
list=list_,
|
||||
delete_previous_version=delete_previous_version,
|
||||
tag=tag,
|
||||
),
|
||||
start=start,
|
||||
expiry=expiry,
|
||||
)
|
||||
|
||||
with_sas = ContainerClient(
|
||||
get_url(account_name),
|
||||
container_name=container_name,
|
||||
credential=sas,
|
||||
)
|
||||
return cast(str, with_sas.url)
|
||||
|
||||
|
||||
def get_container_sas_url(
|
||||
container: Container,
|
||||
storage_type: StorageType,
|
||||
*,
|
||||
read: bool = False,
|
||||
write: bool = False,
|
||||
delete: bool = False,
|
||||
list_: bool = False,
|
||||
) -> str:
|
||||
client = find_container(container, storage_type)
|
||||
if not client:
|
||||
raise Exception("unable to create container sas for missing container")
|
||||
|
||||
return get_container_sas_url_service(
|
||||
client,
|
||||
read=read,
|
||||
write=write,
|
||||
delete=delete,
|
||||
list_=list_,
|
||||
)
|
||||
|
||||
|
||||
def get_file_url(container: Container, name: str, storage_type: StorageType) -> str:
|
||||
client = find_container(container, storage_type)
|
||||
if not client:
|
||||
raise Exception("unable to find container: %s - %s" % (container, storage_type))
|
||||
|
||||
# get_url has a trailing '/'
|
||||
return f"{get_url(client.account_name)}{container}/{name}"
|
||||
|
||||
|
||||
def get_file_sas_url(
|
||||
container: Container,
|
||||
name: str,
|
||||
storage_type: StorageType,
|
||||
*,
|
||||
read: bool = False,
|
||||
add: bool = False,
|
||||
create: bool = False,
|
||||
write: bool = False,
|
||||
delete: bool = False,
|
||||
delete_previous_version: bool = False,
|
||||
tag: bool = False,
|
||||
duration: datetime.timedelta = CONTAINER_SAS_DEFAULT_DURATION,
|
||||
) -> str:
|
||||
client = find_container(container, storage_type)
|
||||
if not client:
|
||||
raise Exception("unable to find container: %s - %s" % (container, storage_type))
|
||||
|
||||
account_key = get_storage_account_name_key_by_name(client.account_name)
|
||||
start, expiry = sas_time_window(duration)
|
||||
|
||||
permission = BlobSasPermissions(
|
||||
read=read,
|
||||
add=add,
|
||||
create=create,
|
||||
write=write,
|
||||
delete=delete,
|
||||
delete_previous_version=delete_previous_version,
|
||||
tag=tag,
|
||||
)
|
||||
sas = generate_blob_sas(
|
||||
client.account_name,
|
||||
container,
|
||||
name,
|
||||
account_key=account_key,
|
||||
permission=permission,
|
||||
expiry=expiry,
|
||||
start=start,
|
||||
)
|
||||
|
||||
with_sas = BlobClient(
|
||||
get_url(client.account_name),
|
||||
container,
|
||||
name,
|
||||
credential=sas,
|
||||
)
|
||||
return cast(str, with_sas.url)
|
||||
|
||||
|
||||
def save_blob(
|
||||
container: Container,
|
||||
name: str,
|
||||
data: Union[str, bytes],
|
||||
storage_type: StorageType,
|
||||
) -> None:
|
||||
client = find_container(container, storage_type)
|
||||
if not client:
|
||||
raise Exception("unable to find container: %s - %s" % (container, storage_type))
|
||||
|
||||
client.get_blob_client(name).upload_blob(data, overwrite=True)
|
||||
|
||||
|
||||
def get_blob(
|
||||
container: Container, name: str, storage_type: StorageType
|
||||
) -> Optional[bytes]:
|
||||
client = find_container(container, storage_type)
|
||||
if not client:
|
||||
return None
|
||||
|
||||
try:
|
||||
return cast(
|
||||
bytes, client.get_blob_client(name).download_blob().content_as_bytes()
|
||||
)
|
||||
except AzureMissingResourceHttpError:
|
||||
return None
|
||||
|
||||
|
||||
def blob_exists(container: Container, name: str, storage_type: StorageType) -> bool:
|
||||
client = find_container(container, storage_type)
|
||||
if not client:
|
||||
return False
|
||||
|
||||
return cast(bool, client.get_blob_client(name).exists())
|
||||
|
||||
|
||||
def delete_blob(container: Container, name: str, storage_type: StorageType) -> bool:
|
||||
client = find_container(container, storage_type)
|
||||
if not client:
|
||||
return False
|
||||
|
||||
try:
|
||||
client.get_blob_client(name).delete_blob()
|
||||
return True
|
||||
except AzureMissingResourceHttpError:
|
||||
return False
|
||||
|
||||
|
||||
def auth_download_url(container: Container, filename: str) -> str:
|
||||
instance = os.environ["ONEFUZZ_INSTANCE"]
|
||||
return "%s/api/download?%s" % (
|
||||
instance,
|
||||
urllib.parse.urlencode({"container": container, "filename": filename}),
|
||||
)
|
@ -1,246 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
import urllib.parse
|
||||
from typing import Any, Callable, Dict, List, Optional, TypeVar, cast
|
||||
from uuid import UUID
|
||||
|
||||
import requests
|
||||
from azure.core.exceptions import ClientAuthenticationError
|
||||
from azure.identity import DefaultAzureCredential
|
||||
from azure.keyvault.secrets import SecretClient
|
||||
from azure.mgmt.resource import ResourceManagementClient
|
||||
from azure.mgmt.subscription import SubscriptionClient
|
||||
from memoization import cached
|
||||
from msrestazure.azure_active_directory import MSIAuthentication
|
||||
from msrestazure.tools import parse_resource_id
|
||||
from onefuzztypes.primitives import Container, Region
|
||||
|
||||
from .monkeypatch import allow_more_workers, reduce_logging
|
||||
|
||||
# https://docs.microsoft.com/en-us/graph/api/overview?view=graph-rest-1.0
|
||||
GRAPH_RESOURCE = "https://graph.microsoft.com"
|
||||
GRAPH_RESOURCE_ENDPOINT = "https://graph.microsoft.com/v1.0"
|
||||
|
||||
|
||||
@cached
|
||||
def get_msi() -> MSIAuthentication:
|
||||
allow_more_workers()
|
||||
reduce_logging()
|
||||
return MSIAuthentication()
|
||||
|
||||
|
||||
@cached
|
||||
def get_identity() -> DefaultAzureCredential:
|
||||
allow_more_workers()
|
||||
reduce_logging()
|
||||
return DefaultAzureCredential()
|
||||
|
||||
|
||||
@cached
|
||||
def get_base_resource_group() -> Any: # should be str
|
||||
return parse_resource_id(os.environ["ONEFUZZ_RESOURCE_GROUP"])["resource_group"]
|
||||
|
||||
|
||||
@cached
|
||||
def get_base_region() -> Region:
|
||||
client = ResourceManagementClient(
|
||||
credential=get_identity(), subscription_id=get_subscription()
|
||||
)
|
||||
group = client.resource_groups.get(get_base_resource_group())
|
||||
return Region(group.location)
|
||||
|
||||
|
||||
@cached
|
||||
def get_subscription() -> Any: # should be str
|
||||
return parse_resource_id(os.environ["ONEFUZZ_DATA_STORAGE"])["subscription"]
|
||||
|
||||
|
||||
@cached
|
||||
def get_insights_instrumentation_key() -> Any: # should be str
|
||||
return os.environ["APPINSIGHTS_INSTRUMENTATIONKEY"]
|
||||
|
||||
|
||||
@cached
|
||||
def get_insights_appid() -> str:
|
||||
return os.environ["APPINSIGHTS_APPID"]
|
||||
|
||||
|
||||
@cached
|
||||
def get_instance_name() -> str:
|
||||
return os.environ["ONEFUZZ_INSTANCE_NAME"]
|
||||
|
||||
|
||||
@cached
|
||||
def get_instance_url() -> str:
|
||||
return "https://%s.azurewebsites.net" % get_instance_name()
|
||||
|
||||
|
||||
@cached
|
||||
def get_agent_instance_url() -> str:
|
||||
return get_instance_url()
|
||||
|
||||
|
||||
@cached
|
||||
def get_instance_id() -> UUID:
|
||||
from .containers import get_blob
|
||||
from .storage import StorageType
|
||||
|
||||
blob = get_blob(Container("base-config"), "instance_id", StorageType.config)
|
||||
if blob is None:
|
||||
raise Exception("missing instance_id")
|
||||
return UUID(blob.decode())
|
||||
|
||||
|
||||
DAY_IN_SECONDS = 60 * 60 * 24
|
||||
|
||||
|
||||
@cached(ttl=DAY_IN_SECONDS)
|
||||
def get_regions() -> List[Region]:
|
||||
subscription = get_subscription()
|
||||
client = SubscriptionClient(credential=get_identity())
|
||||
locations = client.subscriptions.list_locations(subscription)
|
||||
return sorted([Region(x.name) for x in locations])
|
||||
|
||||
|
||||
class GraphQueryError(Exception):
|
||||
def __init__(self, message: str, status_code: Optional[int]) -> None:
|
||||
super(GraphQueryError, self).__init__(message)
|
||||
self.message = message
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
def query_microsoft_graph(
|
||||
method: str,
|
||||
resource: str,
|
||||
params: Optional[Dict] = None,
|
||||
body: Optional[Dict] = None,
|
||||
) -> Dict:
|
||||
cred = get_identity()
|
||||
access_token = cred.get_token(f"{GRAPH_RESOURCE}/.default")
|
||||
|
||||
url = urllib.parse.urljoin(f"{GRAPH_RESOURCE_ENDPOINT}/", resource)
|
||||
headers = {
|
||||
"Authorization": "Bearer %s" % access_token.token,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
response = requests.request(
|
||||
method=method, url=url, headers=headers, params=params, json=body
|
||||
)
|
||||
|
||||
if 200 <= response.status_code < 300:
|
||||
if response.content and response.content.strip():
|
||||
json = response.json()
|
||||
if isinstance(json, Dict):
|
||||
return json
|
||||
else:
|
||||
raise GraphQueryError(
|
||||
"invalid data expected a json object: HTTP"
|
||||
f" {response.status_code} - {json}",
|
||||
response.status_code,
|
||||
)
|
||||
else:
|
||||
return {}
|
||||
else:
|
||||
error_text = str(response.content, encoding="utf-8", errors="backslashreplace")
|
||||
raise GraphQueryError(
|
||||
f"request did not succeed: HTTP {response.status_code} - {error_text}",
|
||||
response.status_code,
|
||||
)
|
||||
|
||||
|
||||
def query_microsoft_graph_list(
|
||||
method: str,
|
||||
resource: str,
|
||||
params: Optional[Dict] = None,
|
||||
body: Optional[Dict] = None,
|
||||
) -> List[Any]:
|
||||
result = query_microsoft_graph(
|
||||
method,
|
||||
resource,
|
||||
params,
|
||||
body,
|
||||
)
|
||||
value = result.get("value")
|
||||
if isinstance(value, list):
|
||||
return value
|
||||
else:
|
||||
raise GraphQueryError("Expected data containing a list of values", None)
|
||||
|
||||
|
||||
@cached
|
||||
def get_scaleset_identity_resource_path() -> str:
|
||||
scaleset_id_name = "%s-scalesetid" % get_instance_name()
|
||||
resource_group_path = "/subscriptions/%s/resourceGroups/%s/providers" % (
|
||||
get_subscription(),
|
||||
get_base_resource_group(),
|
||||
)
|
||||
return "%s/Microsoft.ManagedIdentity/userAssignedIdentities/%s" % (
|
||||
resource_group_path,
|
||||
scaleset_id_name,
|
||||
)
|
||||
|
||||
|
||||
@cached
|
||||
def get_scaleset_principal_id() -> UUID:
|
||||
api_version = "2018-11-30" # matches the apiversion in the deployment template
|
||||
client = ResourceManagementClient(
|
||||
credential=get_identity(), subscription_id=get_subscription()
|
||||
)
|
||||
uid = client.resources.get_by_id(get_scaleset_identity_resource_path(), api_version)
|
||||
return UUID(uid.properties["principalId"])
|
||||
|
||||
|
||||
@cached
|
||||
def get_keyvault_client(vault_url: str) -> SecretClient:
|
||||
return SecretClient(vault_url=vault_url, credential=DefaultAzureCredential())
|
||||
|
||||
|
||||
def clear_azure_client_cache() -> None:
|
||||
# clears the memoization of the Azure clients.
|
||||
|
||||
from .compute import get_compute_client
|
||||
from .containers import get_blob_service
|
||||
from .network_mgmt_client import get_network_client
|
||||
from .storage import get_mgmt_client
|
||||
|
||||
# currently memoization.cache does not project the wrapped function's types.
|
||||
# As a workaround, CI comments out the `cached` wrapper, then runs the type
|
||||
# validation. This enables calling the wrapper's clear_cache if it's not
|
||||
# disabled.
|
||||
for func in [
|
||||
get_msi,
|
||||
get_identity,
|
||||
get_compute_client,
|
||||
get_blob_service,
|
||||
get_network_client,
|
||||
get_mgmt_client,
|
||||
]:
|
||||
clear_func = getattr(func, "clear_cache", None)
|
||||
if clear_func is not None:
|
||||
clear_func()
|
||||
|
||||
|
||||
T = TypeVar("T", bound=Callable[..., Any])
|
||||
|
||||
|
||||
class retry_on_auth_failure:
|
||||
def __call__(self, func: T) -> T:
|
||||
@functools.wraps(func)
|
||||
def decorated(*args, **kwargs): # type: ignore
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except ClientAuthenticationError as err:
|
||||
logging.warning(
|
||||
"clearing authentication cache after auth failure: %s", err
|
||||
)
|
||||
|
||||
clear_azure_client_cache()
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return cast(T, decorated)
|
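A hedged usage sketch for retry_on_auth_failure; list_my_resources is a hypothetical function, not part of the original module. On a ClientAuthenticationError the decorator clears the memoized Azure clients and retries the call once.

# hypothetical example of decorating a call that needs fresh credentials
@retry_on_auth_failure()
def list_my_resources() -> list:
    client = ResourceManagementClient(
        credential=get_identity(), subscription_id=get_subscription()
    )
    return list(client.resources.list())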
@ -1,29 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from azure.core.exceptions import ResourceNotFoundError
|
||||
from msrestazure.azure_exceptions import CloudError
|
||||
|
||||
from .compute import get_compute_client
|
||||
|
||||
|
||||
def list_disks(resource_group: str) -> Any:
|
||||
logging.info("listing disks %s", resource_group)
|
||||
compute_client = get_compute_client()
|
||||
return compute_client.disks.list_by_resource_group(resource_group)
|
||||
|
||||
|
||||
def delete_disk(resource_group: str, name: str) -> bool:
|
||||
logging.info("deleting disks %s : %s", resource_group, name)
|
||||
compute_client = get_compute_client()
|
||||
try:
|
||||
compute_client.disks.begin_delete(resource_group, name)
|
||||
return True
|
||||
except (ResourceNotFoundError, CloudError) as err:
|
||||
logging.error("unable to delete disk: %s", err)
|
||||
return False
|
@ -1,51 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import Dict, List, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from ..config import InstanceConfig
|
||||
from .creds import query_microsoft_graph_list
|
||||
|
||||
|
||||
class GroupMembershipChecker(Protocol):
|
||||
def is_member(self, group_ids: List[UUID], member_id: UUID) -> bool:
|
||||
"""Check if member is part of at least one of the groups"""
|
||||
if member_id in group_ids:
|
||||
return True
|
||||
|
||||
groups = self.get_groups(member_id)
|
||||
for g in group_ids:
|
||||
if g in groups:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_groups(self, member_id: UUID) -> List[UUID]:
|
||||
"""Gets all the groups of the provided member"""
|
||||
|
||||
|
||||
def create_group_membership_checker() -> GroupMembershipChecker:
|
||||
config = InstanceConfig.fetch()
|
||||
if config.group_membership:
|
||||
return StaticGroupMembership(config.group_membership)
|
||||
else:
|
||||
return AzureADGroupMembership()
|
||||
|
||||
|
||||
class AzureADGroupMembership(GroupMembershipChecker):
|
||||
def get_groups(self, member_id: UUID) -> List[UUID]:
|
||||
response = query_microsoft_graph_list(
|
||||
method="GET", resource=f"users/{member_id}/transitiveMemberOf"
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
class StaticGroupMembership(GroupMembershipChecker):
|
||||
def __init__(self, memberships: Dict[str, List[UUID]]):
|
||||
self.memberships = memberships
|
||||
|
||||
def get_groups(self, member_id: UUID) -> List[UUID]:
|
||||
return self.memberships.get(str(member_id), [])
|
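A minimal sketch of how the static membership checker behaves; the UUIDs are invented for illustration:

from uuid import UUID

group = UUID("00000000-0000-0000-0000-000000000001")
user = UUID("00000000-0000-0000-0000-000000000002")

checker = StaticGroupMembership({str(user): [group]})
assert checker.is_member([group], user)             # member via configured group
assert checker.is_member([user], user)              # a direct id match short-circuits
assert not checker.is_member([group], UUID(int=3))  # unknown member has no groups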
@ -1,55 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import Union
|
||||
|
||||
from azure.core.exceptions import ResourceNotFoundError
|
||||
from memoization import cached
|
||||
from msrestazure.azure_exceptions import CloudError
|
||||
from msrestazure.tools import parse_resource_id
|
||||
from onefuzztypes.enums import OS, ErrorCode
|
||||
from onefuzztypes.models import Error
|
||||
from onefuzztypes.primitives import Region
|
||||
|
||||
from .compute import get_compute_client
|
||||
|
||||
|
||||
@cached(ttl=60)
|
||||
def get_os(region: Region, image: str) -> Union[Error, OS]:
|
||||
client = get_compute_client()
|
||||
# The dict returned here may not have any defined keys.
|
||||
#
|
||||
# See: https://github.com/Azure/msrestazure-for-python/blob/v0.6.3/msrestazure/tools.py#L134 # noqa: E501
|
||||
parsed = parse_resource_id(image)
|
||||
if "resource_group" in parsed:
|
||||
if parsed["type"] == "galleries":
|
||||
try:
|
||||
# See: https://docs.microsoft.com/en-us/rest/api/compute/gallery-images/get#galleryimage # noqa: E501
|
||||
name = client.gallery_images.get(
|
||||
parsed["resource_group"], parsed["name"], parsed["child_name_1"]
|
||||
).os_type.lower()
|
||||
except (ResourceNotFoundError, CloudError) as err:
|
||||
return Error(code=ErrorCode.INVALID_IMAGE, errors=[str(err)])
|
||||
else:
|
||||
try:
|
||||
# See: https://docs.microsoft.com/en-us/rest/api/compute/images/get
|
||||
name = client.images.get(
|
||||
parsed["resource_group"], parsed["name"]
|
||||
).storage_profile.os_disk.os_type.lower()
|
||||
except (ResourceNotFoundError, CloudError) as err:
|
||||
return Error(code=ErrorCode.INVALID_IMAGE, errors=[str(err)])
|
||||
else:
|
||||
publisher, offer, sku, version = image.split(":")
|
||||
try:
|
||||
if version == "latest":
|
||||
version = client.virtual_machine_images.list(
|
||||
region, publisher, offer, sku, top=1
|
||||
)[0].name
|
||||
name = client.virtual_machine_images.get(
|
||||
region, publisher, offer, sku, version
|
||||
).os_disk_image.operating_system.lower()
|
||||
except (ResourceNotFoundError, CloudError) as err:
|
||||
return Error(code=ErrorCode.INVALID_IMAGE, errors=[str(err)])
|
||||
return OS[name]
|
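get_os accepts two image reference shapes: a full ARM resource id (handled via parse_resource_id) and a marketplace reference of the form publisher:offer:sku:version. A small sketch of the two formats (the values are examples only):

# 1. ARM resource id, e.g. a shared image gallery image; parse_resource_id
#    yields keys such as "resource_group", "type", "name" and "child_name_1".
gallery_image = (
    "/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/my-rg"
    "/providers/Microsoft.Compute/galleries/my-gallery/images/my-image"
)

# 2. Marketplace reference: publisher:offer:sku:version
publisher, offer, sku, version = "Canonical:UbuntuServer:18.04-LTS:latest".split(":")
assert (publisher, offer, sku, version) == (
    "Canonical", "UbuntuServer", "18.04-LTS", "latest"
)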
@ -1,168 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from azure.core.exceptions import ResourceNotFoundError
|
||||
from azure.mgmt.network.models import Subnet
|
||||
from msrestazure.azure_exceptions import CloudError
|
||||
from msrestazure.tools import parse_resource_id
|
||||
from onefuzztypes.enums import ErrorCode
|
||||
from onefuzztypes.models import Error
|
||||
from onefuzztypes.primitives import Region
|
||||
|
||||
from .creds import get_base_resource_group
|
||||
from .network import Network
|
||||
from .network_mgmt_client import get_network_client
|
||||
from .nsg import NSG
|
||||
from .vmss import get_instance_id
|
||||
|
||||
|
||||
def get_scaleset_instance_ip(scaleset: UUID, machine_id: UUID) -> Optional[str]:
|
||||
instance = get_instance_id(scaleset, machine_id)
|
||||
if not isinstance(instance, str):
|
||||
return None
|
||||
|
||||
resource_group = get_base_resource_group()
|
||||
|
||||
client = get_network_client()
|
||||
intf = client.network_interfaces.list_virtual_machine_scale_set_network_interfaces(
|
||||
resource_group, str(scaleset)
|
||||
)
|
||||
try:
|
||||
for interface in intf:
|
||||
resource = parse_resource_id(interface.virtual_machine.id)
|
||||
if resource.get("resource_name") != instance:
|
||||
continue
|
||||
|
||||
for config in interface.ip_configurations:
|
||||
if config.private_ip_address is None:
|
||||
continue
|
||||
return str(config.private_ip_address)
|
||||
except (ResourceNotFoundError, CloudError):
|
||||
# this can fail if an interface is removed during the iteration
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_ip(resource_group: str, name: str) -> Optional[Any]:
|
||||
logging.info("getting ip %s:%s", resource_group, name)
|
||||
network_client = get_network_client()
|
||||
try:
|
||||
return network_client.public_ip_addresses.get(resource_group, name)
|
||||
except (ResourceNotFoundError, CloudError):
|
||||
return None
|
||||
|
||||
|
||||
def delete_ip(resource_group: str, name: str) -> Any:
|
||||
logging.info("deleting ip %s:%s", resource_group, name)
|
||||
network_client = get_network_client()
|
||||
return network_client.public_ip_addresses.begin_delete(resource_group, name)
|
||||
|
||||
|
||||
def create_ip(resource_group: str, name: str, region: Region) -> Any:
|
||||
logging.info("creating ip for %s:%s in %s", resource_group, name, region)
|
||||
|
||||
network_client = get_network_client()
|
||||
params: Dict[str, Union[str, Dict[str, str]]] = {
|
||||
"location": region,
|
||||
"public_ip_allocation_method": "Dynamic",
|
||||
}
|
||||
if "ONEFUZZ_OWNER" in os.environ:
|
||||
params["tags"] = {"OWNER": os.environ["ONEFUZZ_OWNER"]}
|
||||
return network_client.public_ip_addresses.begin_create_or_update(
|
||||
resource_group, name, params
|
||||
)
|
||||
|
||||
|
||||
def get_public_nic(resource_group: str, name: str) -> Optional[Any]:
|
||||
logging.info("getting nic: %s %s", resource_group, name)
|
||||
network_client = get_network_client()
|
||||
try:
|
||||
return network_client.network_interfaces.get(resource_group, name)
|
||||
except (ResourceNotFoundError, CloudError):
|
||||
return None
|
||||
|
||||
|
||||
def delete_nic(resource_group: str, name: str) -> Optional[Any]:
|
||||
logging.info("deleting nic %s:%s", resource_group, name)
|
||||
network_client = get_network_client()
|
||||
return network_client.network_interfaces.begin_delete(resource_group, name)
|
||||
|
||||
|
||||
def create_public_nic(
|
||||
resource_group: str, name: str, region: Region, nsg: Optional[NSG]
|
||||
) -> Optional[Error]:
|
||||
logging.info("creating nic for %s:%s in %s", resource_group, name, region)
|
||||
|
||||
network = Network(region)
|
||||
subnet_id = network.get_id()
|
||||
if subnet_id is None:
|
||||
network.create()
|
||||
return None
|
||||
|
||||
if nsg:
|
||||
subnet = network.get_subnet()
|
||||
if isinstance(subnet, Subnet) and not subnet.network_security_group:
|
||||
result = nsg.associate_subnet(network.get_vnet(), subnet)
|
||||
if isinstance(result, Error):
|
||||
return result
|
||||
return None
|
||||
|
||||
ip = get_ip(resource_group, name)
|
||||
if not ip:
|
||||
create_ip(resource_group, name, region)
|
||||
return None
|
||||
|
||||
params = {
|
||||
"location": region,
|
||||
"ip_configurations": [
|
||||
{
|
||||
"name": "myIPConfig",
|
||||
"public_ip_address": ip,
|
||||
"subnet": {"id": subnet_id},
|
||||
}
|
||||
],
|
||||
}
|
||||
if "ONEFUZZ_OWNER" in os.environ:
|
||||
params["tags"] = {"OWNER": os.environ["ONEFUZZ_OWNER"]}
|
||||
|
||||
network_client = get_network_client()
|
||||
try:
|
||||
network_client.network_interfaces.begin_create_or_update(
|
||||
resource_group, name, params
|
||||
)
|
||||
except (ResourceNotFoundError, CloudError) as err:
|
||||
if "RetryableError" not in repr(err):
|
||||
return Error(
|
||||
code=ErrorCode.VM_CREATE_FAILED,
|
||||
errors=["unable to create nic: %s" % err],
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def get_public_ip(resource_id: str) -> Optional[str]:
|
||||
logging.info("getting ip for %s", resource_id)
|
||||
network_client = get_network_client()
|
||||
resource = parse_resource_id(resource_id)
|
||||
ip = (
|
||||
network_client.network_interfaces.get(
|
||||
resource["resource_group"], resource["name"]
|
||||
)
|
||||
.ip_configurations[0]
|
||||
.public_ip_address
|
||||
)
|
||||
resource = parse_resource_id(ip.id)
|
||||
ip = network_client.public_ip_addresses.get(
|
||||
resource["resource_group"], resource["name"]
|
||||
).ip_address
|
||||
if ip is None:
|
||||
return None
|
||||
else:
|
||||
return str(ip)
|
@ -1,43 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from azure.mgmt.loganalytics import LogAnalyticsManagementClient
|
||||
from memoization import cached
|
||||
|
||||
from .creds import get_base_resource_group, get_identity, get_subscription
|
||||
|
||||
|
||||
@cached
|
||||
def get_monitor_client() -> LogAnalyticsManagementClient:
|
||||
return LogAnalyticsManagementClient(get_identity(), get_subscription())
|
||||
|
||||
|
||||
@cached(ttl=60)
|
||||
def get_monitor_settings() -> Dict[str, str]:
|
||||
resource_group = get_base_resource_group()
|
||||
workspace_name = os.environ["ONEFUZZ_MONITOR"]
|
||||
client = get_monitor_client()
|
||||
customer_id = client.workspaces.get(resource_group, workspace_name).customer_id
|
||||
shared_key = client.shared_keys.get_shared_keys(
|
||||
resource_group, workspace_name
|
||||
).primary_shared_key
|
||||
return {"id": customer_id, "key": shared_key}
|
||||
|
||||
|
||||
def get_workspace_id() -> str:
|
||||
# TODO:
|
||||
# Once #1679 merges, we can use ONEFUZZ_MONITOR instead of ONEFUZZ_INSTANCE_NAME
|
||||
workspace_id = (
|
||||
"/subscriptions/%s/resourceGroups/%s/providers/microsoft.operationalinsights/workspaces/%s" # noqa: E501
|
||||
% (
|
||||
get_subscription(),
|
||||
get_base_resource_group(),
|
||||
os.environ["ONEFUZZ_INSTANCE_NAME"],
|
||||
)
|
||||
)
|
||||
return workspace_id
|
@ -1,9 +0,0 @@
|
||||
from azure.mgmt.monitor import MonitorManagementClient
|
||||
from memoization import cached
|
||||
|
||||
from .creds import get_identity, get_subscription
|
||||
|
||||
|
||||
@cached
|
||||
def get_monitor_client() -> MonitorManagementClient:
|
||||
return MonitorManagementClient(get_identity(), get_subscription())
|
@ -1,56 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
WORKERS_DONE = False
|
||||
REDUCE_LOGGING = False
|
||||
|
||||
|
||||
def allow_more_workers() -> None:
|
||||
global WORKERS_DONE
|
||||
if WORKERS_DONE:
|
||||
return
|
||||
|
||||
stack = inspect.stack()
|
||||
for entry in stack:
|
||||
if entry.filename.endswith("azure_functions_worker/dispatcher.py"):
|
||||
if entry.frame.f_locals["self"]._sync_call_tp._max_workers == 1:
|
||||
logging.info("bumped thread worker count to 50")
|
||||
entry.frame.f_locals["self"]._sync_call_tp._max_workers = 50
|
||||
|
||||
WORKERS_DONE = True
|
||||
|
||||
|
||||
# TODO: Replace this with a better method for filtering out logging
|
||||
# See https://github.com/Azure/azure-functions-python-worker/issues/743
|
||||
def reduce_logging() -> None:
|
||||
global REDUCE_LOGGING
|
||||
if REDUCE_LOGGING:
|
||||
return
|
||||
|
||||
to_quiet = [
|
||||
"azure",
|
||||
"cli",
|
||||
"grpc",
|
||||
"concurrent",
|
||||
"oauthlib",
|
||||
"msrest",
|
||||
"opencensus",
|
||||
"urllib3",
|
||||
"requests",
|
||||
"aiohttp",
|
||||
"asyncio",
|
||||
"adal-python",
|
||||
]
|
||||
|
||||
for name in logging.Logger.manager.loggerDict:
|
||||
logger = logging.getLogger(name)
|
||||
for prefix in to_quiet:
|
||||
if logger.name.startswith(prefix):
|
||||
logger.level = logging.WARN
|
||||
|
||||
REDUCE_LOGGING = True
|
@ -1,78 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Optional, Union
|
||||
|
||||
from azure.mgmt.network.models import Subnet, VirtualNetwork
|
||||
from msrestazure.azure_exceptions import CloudError
|
||||
from onefuzztypes.enums import ErrorCode
|
||||
from onefuzztypes.models import Error, NetworkConfig
|
||||
from onefuzztypes.primitives import Region
|
||||
|
||||
from ..config import InstanceConfig
|
||||
from .creds import get_base_resource_group
|
||||
from .subnet import create_virtual_network, get_subnet, get_subnet_id, get_vnet
|
||||
|
||||
# This was generated randomly and should be preserved moving forwards
|
||||
NETWORK_GUID_NAMESPACE = uuid.UUID("372977ad-b533-416a-b1b4-f770898e0b11")
|
||||
|
||||
|
||||
class Network:
|
||||
def __init__(self, region: Region):
|
||||
self.group = get_base_resource_group()
|
||||
self.region = region
|
||||
self.network_config = InstanceConfig.fetch().network_config
|
||||
|
||||
# Network names will be calculated from the address_space/subnet
|
||||
# *except* if they are the original values. This allows backwards
|
||||
# compatibility with existing configs if you don't change the network
|
||||
# configs.
|
||||
if (
|
||||
self.network_config.address_space
|
||||
== NetworkConfig.__fields__["address_space"].default
|
||||
and self.network_config.subnet == NetworkConfig.__fields__["subnet"].default
|
||||
):
|
||||
self.name: str = self.region
|
||||
else:
|
||||
network_id = uuid.uuid5(
|
||||
NETWORK_GUID_NAMESPACE,
|
||||
"|".join(
|
||||
[self.network_config.address_space, self.network_config.subnet]
|
||||
),
|
||||
)
|
||||
self.name = f"{self.region}-{network_id}"
|
||||
|
||||
def exists(self) -> bool:
|
||||
return self.get_id() is not None
|
||||
|
||||
def get_id(self) -> Optional[str]:
|
||||
return get_subnet_id(self.group, self.name, self.name)
|
||||
|
||||
def get_subnet(self) -> Optional[Subnet]:
|
||||
return get_subnet(self.group, self.name, self.name)
|
||||
|
||||
def get_vnet(self) -> Optional[VirtualNetwork]:
|
||||
return get_vnet(self.group, self.name)
|
||||
|
||||
def create(self) -> Union[None, Error]:
|
||||
if not self.exists():
|
||||
result = create_virtual_network(
|
||||
self.group, self.name, self.region, self.network_config
|
||||
)
|
||||
if isinstance(result, CloudError):
|
||||
error = Error(
|
||||
code=ErrorCode.UNABLE_TO_CREATE_NETWORK, errors=[result.message]
|
||||
)
|
||||
logging.error(
|
||||
"network creation failed: %s:%s- %s",
|
||||
self.name,
|
||||
self.region,
|
||||
error,
|
||||
)
|
||||
return error
|
||||
|
||||
return None
|
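The derived network name is stable across deployments because uuid5 is deterministic over the namespace and the "address_space|subnet" string. A self-contained sketch (the CIDR values are examples):

import uuid

NETWORK_GUID_NAMESPACE = uuid.UUID("372977ad-b533-416a-b1b4-f770898e0b11")


def derived_network_name(region: str, address_space: str, subnet: str) -> str:
    network_id = uuid.uuid5(
        NETWORK_GUID_NAMESPACE, "|".join([address_space, subnet])
    )
    return f"{region}-{network_id}"


# identical inputs always yield the same name, so redeploys reuse the network
a = derived_network_name("eastus", "10.0.0.0/8", "10.0.0.0/16")
b = derived_network_name("eastus", "10.0.0.0/8", "10.0.0.0/16")
assert a == b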
@ -1,9 +0,0 @@
|
||||
from azure.mgmt.network import NetworkManagementClient
|
||||
from memoization import cached
|
||||
|
||||
from .creds import get_identity, get_subscription
|
||||
|
||||
|
||||
@cached
|
||||
def get_network_client() -> NetworkManagementClient:
|
||||
return NetworkManagementClient(get_identity(), get_subscription())
|
@ -1,489 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, List, Optional, Set, Union, cast
|
||||
|
||||
from azure.core.exceptions import HttpResponseError, ResourceNotFoundError
|
||||
from azure.mgmt.network.models import (
|
||||
NetworkInterface,
|
||||
NetworkSecurityGroup,
|
||||
SecurityRule,
|
||||
SecurityRuleAccess,
|
||||
Subnet,
|
||||
VirtualNetwork,
|
||||
)
|
||||
from msrestazure.azure_exceptions import CloudError
|
||||
from onefuzztypes.enums import ErrorCode
|
||||
from onefuzztypes.models import Error, NetworkSecurityGroupConfig
|
||||
from onefuzztypes.primitives import Region
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
from .creds import get_base_resource_group
|
||||
from .network_mgmt_client import get_network_client
|
||||
|
||||
|
||||
def is_concurrent_request_error(err: str) -> bool:
|
||||
return "The request failed due to conflict with a concurrent request" in str(err)
|
||||
|
||||
|
||||
def get_nsg(name: str) -> Optional[NetworkSecurityGroup]:
|
||||
resource_group = get_base_resource_group()
|
||||
|
||||
logging.debug("getting nsg: %s", name)
|
||||
network_client = get_network_client()
|
||||
try:
|
||||
nsg = network_client.network_security_groups.get(resource_group, name)
|
||||
return cast(NetworkSecurityGroup, nsg)
|
||||
except (ResourceNotFoundError, CloudError) as err:
|
||||
logging.debug("nsg %s does not exist: %s", name, err)
|
||||
return None
|
||||
|
||||
|
||||
def create_nsg(name: str, location: Region) -> Union[None, Error]:
|
||||
resource_group = get_base_resource_group()
|
||||
|
||||
logging.info("creating nsg %s:%s:%s", resource_group, location, name)
|
||||
network_client = get_network_client()
|
||||
|
||||
params: Dict = {
|
||||
"location": location,
|
||||
}
|
||||
|
||||
if "ONEFUZZ_OWNER" in os.environ:
|
||||
params["tags"] = {"OWNER": os.environ["ONEFUZZ_OWNER"]}
|
||||
|
||||
try:
|
||||
network_client.network_security_groups.begin_create_or_update(
|
||||
resource_group, name, params
|
||||
)
|
||||
except (ResourceNotFoundError, CloudError) as err:
|
||||
if is_concurrent_request_error(str(err)):
|
||||
logging.debug(
|
||||
"create NSG had conflicts with concurrent request, ignoring %s", err
|
||||
)
|
||||
return None
|
||||
return Error(
|
||||
code=ErrorCode.UNABLE_TO_CREATE,
|
||||
errors=["Unable to create nsg %s due to %s" % (name, err)],
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def list_nsgs() -> List[NetworkSecurityGroup]:
|
||||
resource_group = get_base_resource_group()
|
||||
network_client = get_network_client()
|
||||
return list(network_client.network_security_groups.list(resource_group))
|
||||
|
||||
|
||||
def update_nsg(nsg: NetworkSecurityGroup) -> Union[None, Error]:
|
||||
resource_group = get_base_resource_group()
|
||||
|
||||
logging.info("updating nsg %s:%s:%s", resource_group, nsg.location, nsg.name)
|
||||
network_client = get_network_client()
|
||||
|
||||
try:
|
||||
network_client.network_security_groups.begin_create_or_update(
|
||||
resource_group, nsg.name, nsg
|
||||
)
|
||||
except (ResourceNotFoundError, CloudError) as err:
|
||||
if is_concurrent_request_error(str(err)):
|
||||
logging.debug(
|
||||
"create NSG had conflicts with concurrent request, ignoring %s", err
|
||||
)
|
||||
return None
|
||||
return Error(
|
||||
code=ErrorCode.UNABLE_TO_CREATE,
|
||||
errors=["Unable to update nsg %s due to %s" % (nsg.name, err)],
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
# Return True if the NSG was created using the OneFuzz naming convention
# and therefore belongs to OneFuzz.
|
||||
def ok_to_delete(active_regions: Set[Region], nsg_region: str, nsg_name: str) -> bool:
|
||||
return nsg_region not in active_regions and nsg_region == nsg_name
|
||||
|
||||
|
||||
def is_onefuzz_nsg(nsg_region: str, nsg_name: str) -> bool:
|
||||
return nsg_region == nsg_name
|
||||
|
||||
|
||||
# Returns True if deletion completed (thus resource not found) or successfully started.
|
||||
# Returns False if failed to start deletion.
|
||||
def start_delete_nsg(name: str) -> bool:
|
||||
# NSG can be only deleted if no other resource is associated with it
|
||||
resource_group = get_base_resource_group()
|
||||
|
||||
logging.info("deleting nsg: %s %s", resource_group, name)
|
||||
network_client = get_network_client()
|
||||
|
||||
try:
|
||||
network_client.network_security_groups.begin_delete(resource_group, name)
|
||||
return True
|
||||
except HttpResponseError as err:
|
||||
err_str = str(err)
|
||||
if (
|
||||
"cannot be deleted because it is in use by the following resources"
|
||||
) in err_str:
|
||||
return False
|
||||
except ResourceNotFoundError:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def set_allowed(name: str, sources: NetworkSecurityGroupConfig) -> Union[None, Error]:
|
||||
resource_group = get_base_resource_group()
|
||||
nsg = get_nsg(name)
|
||||
if not nsg:
|
||||
return Error(
|
||||
code=ErrorCode.UNABLE_TO_FIND,
|
||||
errors=["cannot update nsg rules. nsg %s not found" % name],
|
||||
)
|
||||
|
||||
logging.info(
|
||||
"setting allowed incoming connection sources for nsg: %s %s",
|
||||
resource_group,
|
||||
name,
|
||||
)
|
||||
all_sources = sources.allowed_ips + sources.allowed_service_tags
|
||||
security_rules = []
|
||||
# NSG security rule priority range defined here:
|
||||
# https://docs.microsoft.com/en-us/azure/virtual-network/network-security-groups-overview
|
||||
min_priority = 100
|
||||
# NSG rules per NSG limits:
|
||||
# https://docs.microsoft.com/en-us/azure/azure-resource-manager/management/azure-subscription-service-limits?toc=/azure/virtual-network/toc.json#networking-limits
|
||||
max_rule_count = 1000
|
||||
if len(all_sources) > max_rule_count:
|
||||
return Error(
|
||||
code=ErrorCode.INVALID_REQUEST,
|
||||
errors=[
|
||||
"too many rules provided %d. Max allowed: %d"
|
||||
% ((len(all_sources)), max_rule_count),
|
||||
],
|
||||
)
|
||||
|
||||
priority = min_priority
|
||||
for src in all_sources:
|
||||
security_rules.append(
|
||||
SecurityRule(
|
||||
name="Allow" + str(priority),
|
||||
protocol="*",
|
||||
source_port_range="*",
|
||||
destination_port_range="*",
|
||||
source_address_prefix=src,
|
||||
destination_address_prefix="*",
|
||||
access=SecurityRuleAccess.ALLOW,
|
||||
priority=priority, # between 100 and 4096
|
||||
direction="Inbound",
|
||||
)
|
||||
)
|
||||
# Will not exceed `max_rule_count` or max NSG priority (4096)
|
||||
# due to earlier check of `len(all_sources)`.
|
||||
priority += 1
|
||||
|
||||
nsg.security_rules = security_rules
|
||||
return update_nsg(nsg)
|
||||
|
||||
|
||||
def clear_all_rules(name: str) -> Union[None, Error]:
|
||||
return set_allowed(name, NetworkSecurityGroupConfig())
|
||||
|
||||
|
||||
def get_all_rules(name: str) -> Union[Error, List[SecurityRule]]:
|
||||
nsg = get_nsg(name)
|
||||
if not nsg:
|
||||
return Error(
|
||||
code=ErrorCode.UNABLE_TO_FIND,
|
||||
errors=["cannot get nsg rules. nsg %s not found" % name],
|
||||
)
|
||||
|
||||
return cast(List[SecurityRule], nsg.security_rules)
|
||||
|
||||
|
||||
def associate_nic(name: str, nic: NetworkInterface) -> Union[None, Error]:
|
||||
resource_group = get_base_resource_group()
|
||||
nsg = get_nsg(name)
|
||||
if not nsg:
|
||||
return Error(
|
||||
code=ErrorCode.UNABLE_TO_FIND,
|
||||
errors=["cannot associate nic. nsg %s not found" % name],
|
||||
)
|
||||
|
||||
if nsg.location != nic.location:
|
||||
return Error(
|
||||
code=ErrorCode.UNABLE_TO_UPDATE,
|
||||
errors=[
|
||||
"network interface and nsg have to be in the same region.",
|
||||
"nsg %s %s, nic: %s %s"
|
||||
% (nsg.name, nsg.location, nic.name, nic.location),
|
||||
],
|
||||
)
|
||||
|
||||
if nic.network_security_group and nic.network_security_group.id == nsg.id:
|
||||
logging.info(
|
||||
"NIC %s and NSG %s already associated, not updating", nic.name, name
|
||||
)
|
||||
return None
|
||||
|
||||
logging.info("associating nic %s with nsg: %s %s", nic.name, resource_group, name)
|
||||
|
||||
nic.network_security_group = nsg
|
||||
network_client = get_network_client()
|
||||
try:
|
||||
network_client.network_interfaces.begin_create_or_update(
|
||||
resource_group, nic.name, nic
|
||||
)
|
||||
except (ResourceNotFoundError, CloudError) as err:
|
||||
if is_concurrent_request_error(str(err)):
|
||||
logging.debug(
    "associate NSG with NIC had conflicts "
    "with concurrent request, ignoring %s",
    err,
)
|
||||
return None
|
||||
return Error(
|
||||
code=ErrorCode.UNABLE_TO_UPDATE,
|
||||
errors=[
|
||||
"Unable to associate nsg %s with nic %s due to %s"
|
||||
% (
|
||||
name,
|
||||
nic.name,
|
||||
err,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def dissociate_nic(name: str, nic: NetworkInterface) -> Union[None, Error]:
|
||||
if nic.network_security_group is None:
|
||||
return None
|
||||
resource_group = get_base_resource_group()
|
||||
nsg = get_nsg(name)
|
||||
if not nsg:
|
||||
return Error(
|
||||
code=ErrorCode.UNABLE_TO_FIND,
|
||||
errors=["cannot update nsg rules. nsg %s not found" % name],
|
||||
)
|
||||
if nsg.id != nic.network_security_group.id:
|
||||
return Error(
|
||||
code=ErrorCode.UNABLE_TO_UPDATE,
|
||||
errors=[
|
||||
"network interface is not associated with this nsg.",
|
||||
"nsg %s, nic: %s, nic.nsg: %s"
|
||||
% (
|
||||
nsg.id,
|
||||
nic.name,
|
||||
nic.network_security_group.id,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
logging.info("dissociating nic %s with nsg: %s %s", nic.name, resource_group, name)
|
||||
|
||||
nic.network_security_group = None
|
||||
network_client = get_network_client()
|
||||
try:
|
||||
network_client.network_interfaces.begin_create_or_update(
|
||||
resource_group, nic.name, nic
|
||||
)
|
||||
except (ResourceNotFoundError, CloudError) as err:
|
||||
if is_concurrent_request_error(str(err)):
|
||||
logging.debug(
    "dissociate nsg with nic had conflicts with "
    "concurrent request, ignoring %s",
    err,
)
|
||||
return None
|
||||
return Error(
|
||||
code=ErrorCode.UNABLE_TO_UPDATE,
|
||||
errors=[
|
||||
"Unable to dissociate nsg %s with nic %s due to %s"
|
||||
% (
|
||||
name,
|
||||
nic.name,
|
||||
err,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def associate_subnet(
|
||||
name: str, vnet: VirtualNetwork, subnet: Subnet
|
||||
) -> Union[None, Error]:
|
||||
|
||||
resource_group = get_base_resource_group()
|
||||
nsg = get_nsg(name)
|
||||
if not nsg:
|
||||
return Error(
|
||||
code=ErrorCode.UNABLE_TO_FIND,
|
||||
errors=["cannot associate subnet. nsg %s not found" % name],
|
||||
)
|
||||
|
||||
if nsg.location != vnet.location:
|
||||
return Error(
|
||||
code=ErrorCode.UNABLE_TO_UPDATE,
|
||||
errors=[
|
||||
"subnet and nsg have to be in the same region.",
|
||||
"nsg %s %s, subnet: %s %s"
|
||||
% (nsg.name, nsg.location, subnet.name, subnet.location),
|
||||
],
|
||||
)
|
||||
|
||||
if subnet.network_security_group and subnet.network_security_group.id == nsg.id:
|
||||
logging.info(
|
||||
"Subnet %s and NSG %s already associated, not updating", subnet.name, name
|
||||
)
|
||||
return None
|
||||
|
||||
logging.info(
|
||||
"associating subnet %s with nsg: %s %s", subnet.name, resource_group, name
|
||||
)
|
||||
|
||||
subnet.network_security_group = nsg
|
||||
network_client = get_network_client()
|
||||
try:
|
||||
network_client.subnets.begin_create_or_update(
|
||||
resource_group, vnet.name, subnet.name, subnet
|
||||
)
|
||||
except (ResourceNotFoundError, CloudError) as err:
|
||||
if is_concurrent_request_error(str(err)):
|
||||
logging.debug(
    "associate NSG with subnet had conflicts "
    "with concurrent request, ignoring %s",
    err,
)
|
||||
return None
|
||||
return Error(
|
||||
code=ErrorCode.UNABLE_TO_UPDATE,
|
||||
errors=[
|
||||
"Unable to associate nsg %s with subnet %s due to %s"
|
||||
% (
|
||||
name,
|
||||
subnet.name,
|
||||
err,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def dissociate_subnet(
|
||||
name: str, vnet: VirtualNetwork, subnet: Subnet
|
||||
) -> Union[None, Error]:
|
||||
if subnet.network_security_group is None:
|
||||
return None
|
||||
resource_group = get_base_resource_group()
|
||||
nsg = get_nsg(name)
|
||||
if not nsg:
|
||||
return Error(
|
||||
code=ErrorCode.UNABLE_TO_FIND,
|
||||
errors=["cannot update nsg rules. nsg %s not found" % name],
|
||||
)
|
||||
if nsg.id != subnet.network_security_group.id:
|
||||
return Error(
|
||||
code=ErrorCode.UNABLE_TO_UPDATE,
|
||||
errors=[
|
||||
"subnet is not associated with this nsg.",
|
||||
"nsg %s, subnet: %s, subnet.nsg: %s"
|
||||
% (
|
||||
nsg.id,
|
||||
subnet.name,
|
||||
subnet.network_security_group.id,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
logging.info(
|
||||
"dissociating subnet %s with nsg: %s %s", subnet.name, resource_group, name
|
||||
)
|
||||
|
||||
subnet.network_security_group = None
|
||||
network_client = get_network_client()
|
||||
try:
|
||||
network_client.subnets.begin_create_or_update(
|
||||
resource_group, vnet.name, subnet.name, subnet
|
||||
)
|
||||
except (ResourceNotFoundError, CloudError) as err:
|
||||
if is_concurrent_request_error(str(err)):
|
||||
logging.debug(
    "dissociate nsg with subnet had conflicts with "
    "concurrent request, ignoring %s",
    err,
)
|
||||
return None
|
||||
return Error(
|
||||
code=ErrorCode.UNABLE_TO_UPDATE,
|
||||
errors=[
|
||||
"Unable to dissociate nsg %s with subnet %s due to %s"
|
||||
% (
|
||||
name,
|
||||
subnet.name,
|
||||
err,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class NSG(BaseModel):
|
||||
name: str
|
||||
region: Region
|
||||
|
||||
@validator("name", allow_reuse=True)
|
||||
def check_name(cls, value: str) -> str:
|
||||
# https://docs.microsoft.com/en-us/azure/azure-resource-manager/management/resource-name-rules
|
||||
if len(value) > 80:
|
||||
raise ValueError("NSG name too long")
|
||||
return value
|
||||
|
||||
def create(self) -> Union[None, Error]:
|
||||
# Optimization: if NSG exists - do not try
|
||||
# to create it
|
||||
if self.get() is not None:
|
||||
return None
|
||||
|
||||
return create_nsg(self.name, self.region)
|
||||
|
||||
def start_delete(self) -> bool:
|
||||
return start_delete_nsg(self.name)
|
||||
|
||||
def get(self) -> Optional[NetworkSecurityGroup]:
|
||||
return get_nsg(self.name)
|
||||
|
||||
def set_allowed_sources(
|
||||
self, sources: NetworkSecurityGroupConfig
|
||||
) -> Union[None, Error]:
|
||||
return set_allowed(self.name, sources)
|
||||
|
||||
def clear_all_rules(self) -> Union[None, Error]:
|
||||
return clear_all_rules(self.name)
|
||||
|
||||
def get_all_rules(self) -> Union[Error, List[SecurityRule]]:
|
||||
return get_all_rules(self.name)
|
||||
|
||||
def associate_nic(self, nic: NetworkInterface) -> Union[None, Error]:
|
||||
return associate_nic(self.name, nic)
|
||||
|
||||
def dissociate_nic(self, nic: NetworkInterface) -> Union[None, Error]:
|
||||
return dissociate_nic(self.name, nic)
|
||||
|
||||
def associate_subnet(
|
||||
self, vnet: VirtualNetwork, subnet: Subnet
|
||||
) -> Union[None, Error]:
|
||||
return associate_subnet(self.name, vnet, subnet)
|
||||
|
||||
def dissociate_subnet(
|
||||
self, vnet: VirtualNetwork, subnet: Subnet
|
||||
) -> Union[None, Error]:
|
||||
return dissociate_subnet(self.name, vnet, subnet)
|
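For orientation, a minimal sketch of driving the NSG wrapper above: one NSG per region, named after the region (the convention that is_onefuzz_nsg and ok_to_delete rely on), with inbound sources restricted to an allow-list. The region, IP range, and service tag values are illustrative, not taken from the service.

from onefuzztypes.models import Error, NetworkSecurityGroupConfig
from onefuzztypes.primitives import Region

from .nsg import NSG


def restrict_nsg_example() -> None:
    nsg = NSG(name="eastus", region=Region("eastus"))

    # create() only starts the deployment; in the service the rules are
    # applied on a later pass once get() returns the NSG.
    result = nsg.create()
    if isinstance(result, Error):
        raise RuntimeError("unable to create NSG: %s" % result)

    # Each allowed IP range and service tag becomes one inbound allow rule,
    # with priorities assigned starting at 100.
    config = NetworkSecurityGroupConfig(
        allowed_ips=["10.0.0.0/8"],
        allowed_service_tags=["CorpNetPublic"],
    )
    result = nsg.set_allowed_sources(config)
    if isinstance(result, Error):
        raise RuntimeError("unable to set NSG rules: %s" % result)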
@ -1,203 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import base64
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Optional, Type, TypeVar, Union
|
||||
from uuid import UUID
|
||||
|
||||
from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
|
||||
from azure.storage.queue import (
|
||||
QueueSasPermissions,
|
||||
QueueServiceClient,
|
||||
generate_queue_sas,
|
||||
)
|
||||
from memoization import cached
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .storage import StorageType, get_primary_account, get_storage_account_name_key
|
||||
|
||||
QueueNameType = Union[str, UUID]
|
||||
|
||||
DEFAULT_TTL = -1
|
||||
DEFAULT_DURATION = datetime.timedelta(days=30)
|
||||
|
||||
|
||||
@cached(ttl=60)
|
||||
def get_queue_client(storage_type: StorageType) -> QueueServiceClient:
|
||||
account_id = get_primary_account(storage_type)
|
||||
logging.debug("getting blob container (account_id: %s)", account_id)
|
||||
name, key = get_storage_account_name_key(account_id)
|
||||
account_url = "https://%s.queue.core.windows.net" % name
|
||||
client = QueueServiceClient(
|
||||
account_url=account_url,
|
||||
credential={"account_name": name, "account_key": key},
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
@cached(ttl=60)
|
||||
def get_queue_sas(
|
||||
queue: QueueNameType,
|
||||
storage_type: StorageType,
|
||||
*,
|
||||
read: bool = False,
|
||||
add: bool = False,
|
||||
update: bool = False,
|
||||
process: bool = False,
|
||||
duration: Optional[datetime.timedelta] = None,
|
||||
) -> str:
|
||||
if duration is None:
|
||||
duration = DEFAULT_DURATION
|
||||
account_id = get_primary_account(storage_type)
|
||||
logging.debug("getting queue sas %s (account_id: %s)", queue, account_id)
|
||||
name, key = get_storage_account_name_key(account_id)
|
||||
expiry = datetime.datetime.utcnow() + duration
|
||||
|
||||
token = generate_queue_sas(
|
||||
name,
|
||||
str(queue),
|
||||
key,
|
||||
permission=QueueSasPermissions(
|
||||
read=read, add=add, update=update, process=process
|
||||
),
|
||||
expiry=expiry,
|
||||
)
|
||||
|
||||
url = "https://%s.queue.core.windows.net/%s?%s" % (name, queue, token)
|
||||
return url
|
||||
|
||||
|
||||
@cached(ttl=60)
|
||||
def create_queue(name: QueueNameType, storage_type: StorageType) -> None:
|
||||
client = get_queue_client(storage_type)
|
||||
try:
|
||||
client.create_queue(str(name))
|
||||
except ResourceExistsError:
|
||||
pass
|
||||
|
||||
|
||||
def delete_queue(name: QueueNameType, storage_type: StorageType) -> None:
|
||||
client = get_queue_client(storage_type)
|
||||
queues = client.list_queues()
|
||||
if str(name) in [x["name"] for x in queues]:
|
||||
client.delete_queue(str(name))
|
||||
|
||||
|
||||
def get_queue(
|
||||
name: QueueNameType, storage_type: StorageType
|
||||
) -> Optional[QueueServiceClient]:
|
||||
client = get_queue_client(storage_type)
|
||||
try:
|
||||
return client.get_queue_client(str(name))
|
||||
except ResourceNotFoundError:
|
||||
return None
|
||||
|
||||
|
||||
def clear_queue(name: QueueNameType, storage_type: StorageType) -> None:
|
||||
queue = get_queue(name, storage_type)
|
||||
if queue:
|
||||
try:
|
||||
queue.clear_messages()
|
||||
except ResourceNotFoundError:
|
||||
pass
|
||||
|
||||
|
||||
def send_message(
|
||||
name: QueueNameType,
|
||||
message: bytes,
|
||||
storage_type: StorageType,
|
||||
*,
|
||||
visibility_timeout: Optional[int] = None,
|
||||
time_to_live: int = DEFAULT_TTL,
|
||||
) -> None:
|
||||
queue = get_queue(name, storage_type)
|
||||
if queue:
|
||||
try:
|
||||
queue.send_message(
|
||||
base64.b64encode(message).decode(),
|
||||
visibility_timeout=visibility_timeout,
|
||||
time_to_live=time_to_live,
|
||||
)
|
||||
except ResourceNotFoundError:
|
||||
pass
|
||||
|
||||
|
||||
def remove_first_message(name: QueueNameType, storage_type: StorageType) -> bool:
|
||||
queue = get_queue(name, storage_type)
|
||||
if queue:
|
||||
try:
|
||||
for message in queue.receive_messages():
|
||||
queue.delete_message(message)
|
||||
return True
|
||||
except ResourceNotFoundError:
|
||||
return False
|
||||
return False
|
||||
|
||||
|
||||
A = TypeVar("A", bound=BaseModel)
|
||||
|
||||
|
||||
MIN_PEEK_SIZE = 1
|
||||
MAX_PEEK_SIZE = 32
|
||||
|
||||
|
||||
# Peek at a max of 32 messages
|
||||
# https://docs.microsoft.com/en-us/python/api/azure-storage-queue/azure.storage.queue.queueclient
|
||||
def peek_queue(
|
||||
name: QueueNameType,
|
||||
storage_type: StorageType,
|
||||
*,
|
||||
object_type: Type[A],
|
||||
max_messages: int = MAX_PEEK_SIZE,
|
||||
) -> List[A]:
|
||||
result: List[A] = []
|
||||
|
||||
# message count
|
||||
if max_messages < MIN_PEEK_SIZE or max_messages > MAX_PEEK_SIZE:
|
||||
raise ValueError("invalid max messages: %s" % max_messages)
|
||||
|
||||
try:
|
||||
queue = get_queue(name, storage_type)
|
||||
if not queue:
|
||||
return result
|
||||
|
||||
for message in queue.peek_messages(max_messages=max_messages):
|
||||
decoded = base64.b64decode(message.content)
|
||||
raw = json.loads(decoded)
|
||||
result.append(object_type.parse_obj(raw))
|
||||
except ResourceNotFoundError:
|
||||
return result
|
||||
return result
|
||||
|
||||
|
||||
def queue_object(
|
||||
name: QueueNameType,
|
||||
message: BaseModel,
|
||||
storage_type: StorageType,
|
||||
*,
|
||||
visibility_timeout: Optional[int] = None,
|
||||
time_to_live: int = DEFAULT_TTL,
|
||||
) -> bool:
|
||||
queue = get_queue(name, storage_type)
|
||||
if not queue:
|
||||
raise Exception("unable to queue object, no such queue: %s" % name)
|
||||
|
||||
encoded = base64.b64encode(message.json(exclude_none=True).encode()).decode()
|
||||
try:
|
||||
queue.send_message(
|
||||
encoded, visibility_timeout=visibility_timeout, time_to_live=time_to_live
|
||||
)
|
||||
return True
|
||||
except ResourceNotFoundError:
|
||||
return False
|
||||
|
||||
|
||||
def get_resource_id(queue_name: QueueNameType, storage_type: StorageType) -> str:
|
||||
account_id = get_primary_account(storage_type)
|
||||
resource_uri = "%s/services/queue/queues/%s" % (account_id, queue_name)
|
||||
return resource_uri
|
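A short sketch of the send/peek round trip with the queue helpers above. The Ping model and queue name are made up for illustration; the module path is assumed to match the onefuzzlib layout shown in this patch.

from pydantic import BaseModel

from .queue import create_queue, peek_queue, queue_object
from .storage import StorageType


class Ping(BaseModel):
    ping_id: int


def ping_roundtrip_example() -> None:
    # The queue must exist first; create_queue is idempotent and cached.
    create_queue("example-queue", StorageType.config)

    # Messages are stored as base64-encoded JSON, so any pydantic model works.
    sent = queue_object("example-queue", Ping(ping_id=1), StorageType.config)
    if not sent:
        raise RuntimeError("queue disappeared while sending")

    # peek_queue deserializes back into the requested model type without
    # removing the messages from the queue.
    for msg in peek_queue(
        "example-queue", StorageType.config, object_type=Ping, max_messages=10
    ):
        print(msg.ping_id)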
@ -1,120 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
from enum import Enum
|
||||
from typing import List, Tuple, cast
|
||||
|
||||
from azure.mgmt.storage import StorageManagementClient
|
||||
from memoization import cached
|
||||
from msrestazure.tools import parse_resource_id
|
||||
|
||||
from .creds import get_base_resource_group, get_identity, get_subscription
|
||||
|
||||
|
||||
class StorageType(Enum):
|
||||
corpus = "corpus"
|
||||
config = "config"
|
||||
|
||||
|
||||
@cached
|
||||
def get_mgmt_client() -> StorageManagementClient:
|
||||
return StorageManagementClient(
|
||||
credential=get_identity(), subscription_id=get_subscription()
|
||||
)
|
||||
|
||||
|
||||
@cached
|
||||
def get_fuzz_storage() -> str:
|
||||
return os.environ["ONEFUZZ_DATA_STORAGE"]
|
||||
|
||||
|
||||
@cached
|
||||
def get_func_storage() -> str:
|
||||
return os.environ["ONEFUZZ_FUNC_STORAGE"]
|
||||
|
||||
|
||||
@cached
|
||||
def get_primary_account(storage_type: StorageType) -> str:
|
||||
if storage_type == StorageType.corpus:
|
||||
# see #322 for discussion about typing
|
||||
return get_fuzz_storage()
|
||||
elif storage_type == StorageType.config:
|
||||
# see #322 for discussion about typing
|
||||
return get_func_storage()
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@cached
|
||||
def get_accounts(storage_type: StorageType) -> List[str]:
|
||||
if storage_type == StorageType.corpus:
|
||||
return corpus_accounts()
|
||||
elif storage_type == StorageType.config:
|
||||
return [get_func_storage()]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@cached
|
||||
def get_storage_account_name_key(account_id: str) -> Tuple[str, str]:
|
||||
resource = parse_resource_id(account_id)
|
||||
key = get_storage_account_name_key_by_name(resource["name"])
|
||||
return resource["name"], key
|
||||
|
||||
|
||||
@cached
|
||||
def get_storage_account_name_key_by_name(account_name: str) -> str:
|
||||
client = get_mgmt_client()
|
||||
group = get_base_resource_group()
|
||||
key = client.storage_accounts.list_keys(group, account_name).keys[0].value
|
||||
return cast(str, key)
|
||||
|
||||
|
||||
def choose_account(storage_type: StorageType) -> str:
|
||||
accounts = get_accounts(storage_type)
|
||||
if not accounts:
|
||||
raise Exception(f"no storage accounts for {storage_type}")
|
||||
|
||||
if len(accounts) == 1:
|
||||
return accounts[0]
|
||||
|
||||
# Use a random secondary storage account if any are available. This
|
||||
# reduces IOP contention for the Storage Queues, which are only available
|
||||
# on primary accounts
|
||||
#
|
||||
# security note: this is not used as a security feature
|
||||
return random.choice(accounts[1:]) # nosec
|
||||
|
||||
|
||||
@cached
|
||||
def corpus_accounts() -> List[str]:
|
||||
skip = get_func_storage()
|
||||
results = [get_fuzz_storage()]
|
||||
|
||||
client = get_mgmt_client()
|
||||
group = get_base_resource_group()
|
||||
for account in client.storage_accounts.list_by_resource_group(group):
|
||||
# protection from someone adding the corpus tag to the config account
|
||||
if account.id == skip:
|
||||
continue
|
||||
|
||||
if account.id in results:
|
||||
continue
|
||||
|
||||
if account.primary_endpoints.blob is None:
|
||||
continue
|
||||
|
||||
if (
|
||||
"storage_type" not in account.tags
|
||||
or account.tags["storage_type"] != "corpus"
|
||||
):
|
||||
continue
|
||||
|
||||
results.append(account.id)
|
||||
|
||||
logging.info("corpus accounts: %s", results)
|
||||
return results
|
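A sketch of how the account-selection helpers above combine when writing corpus data; the blob endpoint construction is only hinted at in a comment.

from .storage import StorageType, choose_account, get_storage_account_name_key


def pick_corpus_account_example() -> str:
    # choose_account picks a random secondary "corpus"-tagged account when
    # any exist, otherwise it falls back to the primary data account.
    account_id = choose_account(StorageType.corpus)
    name, _key = get_storage_account_name_key(account_id)
    # A caller would then build a blob client against
    # https://<name>.blob.core.windows.net using the returned key.
    return name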
@ -1,97 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from azure.core.exceptions import ResourceNotFoundError
|
||||
from azure.mgmt.network.models import Subnet, VirtualNetwork
|
||||
from msrestazure.azure_exceptions import CloudError
|
||||
from onefuzztypes.enums import ErrorCode
|
||||
from onefuzztypes.models import Error, NetworkConfig
|
||||
from onefuzztypes.primitives import Region
|
||||
|
||||
from .network_mgmt_client import get_network_client
|
||||
|
||||
|
||||
def get_vnet(resource_group: str, name: str) -> Optional[VirtualNetwork]:
|
||||
network_client = get_network_client()
|
||||
try:
|
||||
vnet = network_client.virtual_networks.get(resource_group, name)
|
||||
return cast(VirtualNetwork, vnet)
|
||||
except (CloudError, ResourceNotFoundError):
|
||||
logging.info(
|
||||
"vnet missing: resource group:%s name:%s",
|
||||
resource_group,
|
||||
name,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def get_subnet(
|
||||
resource_group: str, vnet_name: str, subnet_name: str
|
||||
) -> Optional[Subnet]:
|
||||
# The subnet has to be fetched via its VNet so the NSG field on the subnet is populated
|
||||
vnet = get_vnet(resource_group, vnet_name)
|
||||
if vnet:
|
||||
for subnet in vnet.subnets:
|
||||
if subnet.name == subnet_name:
|
||||
return subnet
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_subnet_id(resource_group: str, name: str, subnet_name: str) -> Optional[str]:
|
||||
subnet = get_subnet(resource_group, name, subnet_name)
|
||||
if subnet and isinstance(subnet.id, str):
|
||||
return subnet.id
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def delete_subnet(resource_group: str, name: str) -> Union[None, CloudError, Any]:
|
||||
network_client = get_network_client()
|
||||
try:
|
||||
return network_client.virtual_networks.begin_delete(resource_group, name)
|
||||
except (CloudError, ResourceNotFoundError) as err:
|
||||
if err.error and "InUseSubnetCannotBeDeleted" in str(err.error):
|
||||
logging.error(
|
||||
"subnet delete failed: %s %s : %s", resource_group, name, repr(err)
|
||||
)
|
||||
return None
|
||||
raise err
|
||||
|
||||
|
||||
def create_virtual_network(
|
||||
resource_group: str,
|
||||
name: str,
|
||||
region: Region,
|
||||
network_config: NetworkConfig,
|
||||
) -> Optional[Error]:
|
||||
logging.info(
|
||||
"creating subnet - resource group:%s name:%s region:%s",
|
||||
resource_group,
|
||||
name,
|
||||
region,
|
||||
)
|
||||
|
||||
network_client = get_network_client()
|
||||
params = {
|
||||
"location": region,
|
||||
"address_space": {"address_prefixes": [network_config.address_space]},
|
||||
"subnets": [{"name": name, "address_prefix": network_config.subnet}],
|
||||
}
|
||||
if "ONEFUZZ_OWNER" in os.environ:
|
||||
params["tags"] = {"OWNER": os.environ["ONEFUZZ_OWNER"]}
|
||||
|
||||
try:
|
||||
network_client.virtual_networks.begin_create_or_update(
|
||||
resource_group, name, params
|
||||
)
|
||||
except (CloudError, ResourceNotFoundError) as err:
|
||||
return Error(code=ErrorCode.UNABLE_TO_CREATE_NETWORK, errors=[str(err)])
|
||||
|
||||
return None
|
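A sketch of the bring-up flow for the VNet helpers above. Creation is asynchronous, so callers poll for the subnet id; the module path, naming scheme, and address ranges below are illustrative assumptions, not taken from the service.

from typing import Optional

from onefuzztypes.models import Error, NetworkConfig
from onefuzztypes.primitives import Region

from .subnet import create_virtual_network, get_subnet_id  # module path assumed


def ensure_subnet_example(resource_group: str, region: Region) -> Optional[str]:
    # Here the VNet and its single subnet are simply named after the region.
    subnet_id = get_subnet_id(resource_group, region, region)
    if subnet_id:
        return subnet_id

    config = NetworkConfig(address_space="10.0.0.0/8", subnet="10.0.0.0/16")
    result = create_virtual_network(resource_group, region, region, config)
    if isinstance(result, Error):
        raise RuntimeError("unable to create vnet: %s" % result)

    # Not ready yet; the caller retries until get_subnet_id returns a value.
    return None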
@ -1,30 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from azure.cosmosdb.table import TableService
|
||||
from memoization import cached
|
||||
|
||||
from .storage import get_storage_account_name_key
|
||||
|
||||
|
||||
@cached(ttl=60)
|
||||
def get_client(
|
||||
table: Optional[str] = None, account_id: Optional[str] = None
|
||||
) -> TableService:
|
||||
if account_id is None:
|
||||
account_id = os.environ["ONEFUZZ_FUNC_STORAGE"]
|
||||
|
||||
logging.debug("getting table account: (account_id: %s)", account_id)
|
||||
name, key = get_storage_account_name_key(account_id)
|
||||
client = TableService(account_name=name, account_key=key)
|
||||
|
||||
if table and not client.exists(table):
|
||||
logging.info("creating missing table %s", table)
|
||||
client.create_table(table, fail_on_exist=False)
|
||||
return client
|
@ -1,313 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Union, cast
|
||||
from uuid import UUID
|
||||
|
||||
from azure.core.exceptions import ResourceNotFoundError
|
||||
from azure.mgmt.compute.models import VirtualMachine
|
||||
from msrestazure.azure_exceptions import CloudError
|
||||
from onefuzztypes.enums import OS, ErrorCode
|
||||
from onefuzztypes.models import Authentication, Error
|
||||
from onefuzztypes.primitives import Extension, Region
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
from .compute import get_compute_client
|
||||
from .creds import get_base_resource_group
|
||||
from .disk import delete_disk, list_disks
|
||||
from .image import get_os
|
||||
from .ip import create_public_nic, delete_ip, delete_nic, get_ip, get_public_nic
|
||||
from .nsg import NSG
|
||||
|
||||
|
||||
def get_vm(name: str) -> Optional[VirtualMachine]:
|
||||
resource_group = get_base_resource_group()
|
||||
|
||||
logging.debug("getting vm: %s", name)
|
||||
compute_client = get_compute_client()
|
||||
try:
|
||||
return cast(
|
||||
VirtualMachine,
|
||||
compute_client.virtual_machines.get(
|
||||
resource_group, name, expand="instanceView"
|
||||
),
|
||||
)
|
||||
except (ResourceNotFoundError, CloudError) as err:
|
||||
logging.debug("vm does not exist %s", err)
|
||||
return None
|
||||
|
||||
|
||||
def create_vm(
|
||||
name: str,
|
||||
location: Region,
|
||||
vm_sku: str,
|
||||
image: str,
|
||||
password: str,
|
||||
ssh_public_key: str,
|
||||
nsg: Optional[NSG],
|
||||
tags: Optional[Dict[str, str]],
|
||||
) -> Union[None, Error]:
|
||||
|
||||
resource_group = get_base_resource_group()
|
||||
logging.info("creating vm %s:%s:%s", resource_group, location, name)
|
||||
|
||||
compute_client = get_compute_client()
|
||||
|
||||
nic = get_public_nic(resource_group, name)
|
||||
if nic is None:
|
||||
result = create_public_nic(resource_group, name, location, nsg)
|
||||
if isinstance(result, Error):
|
||||
return result
|
||||
logging.info("waiting on nic creation")
|
||||
return None
|
||||
|
||||
# by the time the public NIC exists, the VNet must also exist;
# this is guaranteed by the logic of get_public_nic
|
||||
|
||||
if nsg:
|
||||
result = nsg.associate_nic(nic)
|
||||
if isinstance(result, Error):
|
||||
return result
|
||||
|
||||
if image.startswith("/"):
|
||||
image_ref = {"id": image}
|
||||
else:
|
||||
image_val = image.split(":", 4)
|
||||
image_ref = {
|
||||
"publisher": image_val[0],
|
||||
"offer": image_val[1],
|
||||
"sku": image_val[2],
|
||||
"version": image_val[3],
|
||||
}
|
||||
|
||||
params: Dict = {
|
||||
"location": location,
|
||||
"os_profile": {
|
||||
"computer_name": "node",
|
||||
"admin_username": "onefuzz",
|
||||
},
|
||||
"hardware_profile": {"vm_size": vm_sku},
|
||||
"storage_profile": {"image_reference": image_ref},
|
||||
"network_profile": {"network_interfaces": [{"id": nic.id}]},
|
||||
}
|
||||
|
||||
image_os = get_os(location, image)
|
||||
if isinstance(image_os, Error):
|
||||
return image_os
|
||||
|
||||
if image_os == OS.windows:
|
||||
params["os_profile"]["admin_password"] = password
|
||||
|
||||
if image_os == OS.linux:
|
||||
|
||||
params["os_profile"]["linux_configuration"] = {
|
||||
"disable_password_authentication": True,
|
||||
"ssh": {
|
||||
"public_keys": [
|
||||
{
|
||||
"path": "/home/onefuzz/.ssh/authorized_keys",
|
||||
"key_data": ssh_public_key,
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
if "ONEFUZZ_OWNER" in os.environ:
|
||||
params["tags"] = {"OWNER": os.environ["ONEFUZZ_OWNER"]}
|
||||
|
||||
if tags:
|
||||
params["tags"].update(tags.copy())
|
||||
|
||||
try:
|
||||
compute_client.virtual_machines.begin_create_or_update(
|
||||
resource_group, name, params
|
||||
)
|
||||
except (ResourceNotFoundError, CloudError) as err:
|
||||
if "The request failed due to conflict with a concurrent request" in str(err):
|
||||
logging.debug(
|
||||
"create VM had conflicts with concurrent request, ignoring %s", err
|
||||
)
|
||||
return None
|
||||
return Error(code=ErrorCode.VM_CREATE_FAILED, errors=[str(err)])
|
||||
return None
|
||||
|
||||
|
||||
def get_extension(vm_name: str, extension_name: str) -> Optional[Any]:
|
||||
resource_group = get_base_resource_group()
|
||||
|
||||
logging.debug(
|
||||
"getting extension: %s:%s:%s",
|
||||
resource_group,
|
||||
vm_name,
|
||||
extension_name,
|
||||
)
|
||||
compute_client = get_compute_client()
|
||||
try:
|
||||
return compute_client.virtual_machine_extensions.get(
|
||||
resource_group, vm_name, extension_name
|
||||
)
|
||||
except (ResourceNotFoundError, CloudError) as err:
|
||||
logging.info("extension does not exist %s", err)
|
||||
return None
|
||||
|
||||
|
||||
def create_extension(vm_name: str, extension: Dict) -> Any:
|
||||
resource_group = get_base_resource_group()
|
||||
|
||||
logging.info(
|
||||
"creating extension: %s:%s:%s", resource_group, vm_name, extension["name"]
|
||||
)
|
||||
compute_client = get_compute_client()
|
||||
return compute_client.virtual_machine_extensions.begin_create_or_update(
|
||||
resource_group, vm_name, extension["name"], extension
|
||||
)
|
||||
|
||||
|
||||
def delete_vm(name: str) -> Any:
|
||||
resource_group = get_base_resource_group()
|
||||
|
||||
logging.info("deleting vm: %s %s", resource_group, name)
|
||||
compute_client = get_compute_client()
|
||||
return compute_client.virtual_machines.begin_delete(resource_group, name)
|
||||
|
||||
|
||||
def has_components(name: str) -> bool:
|
||||
# check if any of the components associated with a VM still exist.
|
||||
#
|
||||
# Azure VM Deletion requires we first delete the VM, then delete all of its
|
||||
# resources. This is required to ensure we've cleaned it all up before
|
||||
# marking it "done"
|
||||
resource_group = get_base_resource_group()
|
||||
if get_vm(name):
|
||||
return True
|
||||
if get_public_nic(resource_group, name):
|
||||
return True
|
||||
if get_ip(resource_group, name):
|
||||
return True
|
||||
|
||||
disks = [x.name for x in list_disks(resource_group) if x.name.startswith(name)]
|
||||
if disks:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def delete_vm_components(name: str, nsg: Optional[NSG]) -> bool:
|
||||
resource_group = get_base_resource_group()
|
||||
logging.info("deleting vm components %s:%s", resource_group, name)
|
||||
if get_vm(name):
|
||||
logging.info("deleting vm %s:%s", resource_group, name)
|
||||
delete_vm(name)
|
||||
return False
|
||||
|
||||
nic = get_public_nic(resource_group, name)
|
||||
if nic:
|
||||
logging.info("deleting nic %s:%s", resource_group, name)
|
||||
if nic.network_security_group and nsg:
|
||||
nsg.dissociate_nic(nic)
|
||||
return False
|
||||
delete_nic(resource_group, name)
|
||||
return False
|
||||
|
||||
if get_ip(resource_group, name):
|
||||
logging.info("deleting ip %s:%s", resource_group, name)
|
||||
delete_ip(resource_group, name)
|
||||
return False
|
||||
|
||||
disks = [x.name for x in list_disks(resource_group) if x.name.startswith(name)]
|
||||
if disks:
|
||||
for disk in disks:
|
||||
logging.info("deleting disk %s:%s", resource_group, disk)
|
||||
delete_disk(resource_group, disk)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class VM(BaseModel):
|
||||
name: Union[UUID, str]
|
||||
region: Region
|
||||
sku: str
|
||||
image: str
|
||||
auth: Authentication
|
||||
nsg: Optional[NSG]
|
||||
tags: Optional[Dict[str, str]]
|
||||
|
||||
@validator("name", allow_reuse=True)
|
||||
def check_name(cls, value: Union[UUID, str]) -> Union[UUID, str]:
|
||||
if isinstance(value, str):
|
||||
if len(value) > 40:
|
||||
# Azure truncates resources if the names are longer than 40
|
||||
# bytes
|
||||
raise ValueError("VM name too long")
|
||||
return value
|
||||
|
||||
def is_deleted(self) -> bool:
|
||||
# A VM is considered deleted once all of its resources, including disks,
# NICs, and IPs, as well as the VM itself, are deleted
|
||||
return not has_components(str(self.name))
|
||||
|
||||
def exists(self) -> bool:
|
||||
return self.get() is not None
|
||||
|
||||
def get(self) -> Optional[VirtualMachine]:
|
||||
return get_vm(str(self.name))
|
||||
|
||||
def create(self) -> Union[None, Error]:
|
||||
if self.get() is not None:
|
||||
return None
|
||||
|
||||
logging.info("vm creating: %s", self.name)
|
||||
return create_vm(
|
||||
str(self.name),
|
||||
self.region,
|
||||
self.sku,
|
||||
self.image,
|
||||
self.auth.password,
|
||||
self.auth.public_key,
|
||||
self.nsg,
|
||||
self.tags,
|
||||
)
|
||||
|
||||
def delete(self) -> bool:
|
||||
return delete_vm_components(str(self.name), self.nsg)
|
||||
|
||||
def add_extensions(self, extensions: List[Extension]) -> Union[bool, Error]:
|
||||
status = []
|
||||
to_create = []
|
||||
for config in extensions:
|
||||
if not isinstance(config["name"], str):
|
||||
logging.error("vm agent - incompatable name: %s", repr(config))
|
||||
continue
|
||||
extension = get_extension(str(self.name), config["name"])
|
||||
|
||||
if extension:
|
||||
logging.info(
|
||||
"vm extension state: %s - %s - %s",
|
||||
self.name,
|
||||
config["name"],
|
||||
extension.provisioning_state,
|
||||
)
|
||||
status.append(extension.provisioning_state)
|
||||
else:
|
||||
to_create.append(config)
|
||||
|
||||
if to_create:
|
||||
for config in to_create:
|
||||
create_extension(str(self.name), config)
|
||||
else:
|
||||
if all([x == "Succeeded" for x in status]):
|
||||
return True
|
||||
elif "Failed" in status:
|
||||
return Error(
|
||||
code=ErrorCode.VM_CREATE_FAILED,
|
||||
errors=["failed to launch extension"],
|
||||
)
|
||||
elif not ("Creating" in status or "Updating" in status):
|
||||
logging.error("vm agent - unknown state %s: %s", self.name, status)
|
||||
|
||||
return False
|
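A sketch of the intended lifecycle around the VM model above: create is re-entrant and deletion is incremental, so both are meant to be re-invoked by the caller. The module path, image, SKU, and Authentication values are placeholders.

from onefuzztypes.models import Authentication, Error
from onefuzztypes.primitives import Region

from .vm import VM  # module path assumed


def vm_lifecycle_example() -> None:
    vm = VM(
        name="repro-example",
        region=Region("eastus"),
        sku="Standard_DS1_v2",
        image="Canonical:UbuntuServer:18.04-LTS:latest",
        # placeholder credentials; generated elsewhere in practice
        auth=Authentication(
            password="<placeholder>",
            public_key="<placeholder public key>",
            private_key="<placeholder private key>",
        ),
        nsg=None,
        tags=None,
    )

    # create() is re-entrant: it returns None both while the NIC is still
    # being provisioned and when the VM already exists, so callers poll it.
    result = vm.create()
    if isinstance(result, Error):
        raise RuntimeError("vm create failed: %s" % result)

    # delete() removes one associated resource per call and only reports
    # True once the VM, NIC, IP, and disks are all gone, so it is also
    # re-invoked until it succeeds.
    finished = vm.delete()
    print("fully deleted:", finished)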
@ -1,541 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Union, cast
|
||||
from uuid import UUID
|
||||
|
||||
from azure.core.exceptions import (
|
||||
HttpResponseError,
|
||||
ResourceExistsError,
|
||||
ResourceNotFoundError,
|
||||
)
|
||||
from azure.mgmt.compute.models import (
|
||||
ResourceSku,
|
||||
ResourceSkuRestrictionsType,
|
||||
VirtualMachineScaleSetVMInstanceIDs,
|
||||
VirtualMachineScaleSetVMInstanceRequiredIDs,
|
||||
VirtualMachineScaleSetVMListResult,
|
||||
VirtualMachineScaleSetVMProtectionPolicy,
|
||||
)
|
||||
from memoization import cached
|
||||
from msrestazure.azure_exceptions import CloudError
|
||||
from onefuzztypes.enums import OS, ErrorCode
|
||||
from onefuzztypes.models import Error
|
||||
from onefuzztypes.primitives import Region
|
||||
|
||||
from .compute import get_compute_client
|
||||
from .creds import (
|
||||
get_base_resource_group,
|
||||
get_scaleset_identity_resource_path,
|
||||
retry_on_auth_failure,
|
||||
)
|
||||
from .image import get_os
|
||||
|
||||
|
||||
@retry_on_auth_failure()
|
||||
def list_vmss(
|
||||
name: UUID,
|
||||
vm_filter: Optional[Callable[[VirtualMachineScaleSetVMListResult], bool]] = None,
|
||||
) -> Optional[List[str]]:
|
||||
resource_group = get_base_resource_group()
|
||||
client = get_compute_client()
|
||||
try:
|
||||
instances = [
|
||||
x.instance_id
|
||||
for x in client.virtual_machine_scale_set_vms.list(
|
||||
resource_group, str(name)
|
||||
)
|
||||
if vm_filter is None or vm_filter(x)
|
||||
]
|
||||
return instances
|
||||
except (ResourceNotFoundError, CloudError) as err:
|
||||
logging.error("cloud error listing vmss: %s (%s)", name, err)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@retry_on_auth_failure()
|
||||
def delete_vmss(name: UUID) -> bool:
|
||||
resource_group = get_base_resource_group()
|
||||
compute_client = get_compute_client()
|
||||
response = compute_client.virtual_machine_scale_sets.begin_delete(
|
||||
resource_group, str(name)
|
||||
)
|
||||
|
||||
# https://docs.microsoft.com/en-us/python/api/azure-core/
|
||||
# azure.core.polling.lropoller?view=azure-python#status--
|
||||
#
|
||||
# status returns a str, however mypy thinks this is an Any.
|
||||
#
|
||||
# Checked by hand that the result is Succeeded in practice
|
||||
return bool(response.status() == "Succeeded")
|
||||
|
||||
|
||||
@retry_on_auth_failure()
|
||||
def get_vmss(name: UUID) -> Optional[Any]:
|
||||
resource_group = get_base_resource_group()
|
||||
logging.debug("getting vm: %s", name)
|
||||
compute_client = get_compute_client()
|
||||
try:
|
||||
return compute_client.virtual_machine_scale_sets.get(resource_group, str(name))
|
||||
except ResourceNotFoundError as err:
|
||||
logging.debug("vm does not exist %s", err)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@retry_on_auth_failure()
|
||||
def resize_vmss(name: UUID, capacity: int) -> None:
|
||||
check_can_update(name)
|
||||
|
||||
resource_group = get_base_resource_group()
|
||||
logging.info("updating VM count - name: %s vm_count: %d", name, capacity)
|
||||
compute_client = get_compute_client()
|
||||
try:
|
||||
compute_client.virtual_machine_scale_sets.begin_update(
|
||||
resource_group, str(name), {"sku": {"capacity": capacity}}
|
||||
)
|
||||
except ResourceExistsError as err:
|
||||
logging.error(
|
||||
"unable to resize scaleset. name:%s vm_count:%d - err:%s",
|
||||
name,
|
||||
capacity,
|
||||
err,
|
||||
)
|
||||
except HttpResponseError as err:
|
||||
if (
|
||||
"models that may be referenced by one or more"
|
||||
+ " VMs belonging to the Virtual Machine Scale Set"
|
||||
in str(err)
|
||||
):
|
||||
logging.error(
|
||||
"unable to resize scaleset due to model error. name: %s - err:%s",
|
||||
name,
|
||||
err,
|
||||
)
|
||||
|
||||
|
||||
@retry_on_auth_failure()
|
||||
def get_vmss_size(name: UUID) -> Optional[int]:
|
||||
vmss = get_vmss(name)
|
||||
if vmss is None:
|
||||
return None
|
||||
return cast(int, vmss.sku.capacity)
|
||||
|
||||
|
||||
@retry_on_auth_failure()
|
||||
def list_instance_ids(name: UUID) -> Dict[UUID, str]:
|
||||
logging.debug("get instance IDs for scaleset: %s", name)
|
||||
resource_group = get_base_resource_group()
|
||||
compute_client = get_compute_client()
|
||||
|
||||
results = {}
|
||||
try:
|
||||
for instance in compute_client.virtual_machine_scale_set_vms.list(
|
||||
resource_group, str(name)
|
||||
):
|
||||
results[UUID(instance.vm_id)] = cast(str, instance.instance_id)
|
||||
except (ResourceNotFoundError, CloudError):
|
||||
logging.debug("vm does not exist %s", name)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@cached(ttl=60)
|
||||
@retry_on_auth_failure()
|
||||
def get_instance_id(name: UUID, vm_id: UUID) -> Union[str, Error]:
|
||||
resource_group = get_base_resource_group()
|
||||
logging.info("get instance ID for scaleset node: %s:%s", name, vm_id)
|
||||
compute_client = get_compute_client()
|
||||
|
||||
vm_id_str = str(vm_id)
|
||||
for instance in compute_client.virtual_machine_scale_set_vms.list(
|
||||
resource_group, str(name)
|
||||
):
|
||||
if instance.vm_id == vm_id_str:
|
||||
return cast(str, instance.instance_id)
|
||||
|
||||
return Error(
|
||||
code=ErrorCode.UNABLE_TO_FIND,
|
||||
errors=["unable to find scaleset machine: %s:%s" % (name, vm_id)],
|
||||
)
|
||||
|
||||
|
||||
@retry_on_auth_failure()
|
||||
def update_scale_in_protection(
|
||||
name: UUID, vm_id: UUID, protect_from_scale_in: bool
|
||||
) -> Optional[Error]:
|
||||
instance_id = get_instance_id(name, vm_id)
|
||||
|
||||
if isinstance(instance_id, Error):
|
||||
return instance_id
|
||||
|
||||
compute_client = get_compute_client()
|
||||
resource_group = get_base_resource_group()
|
||||
|
||||
try:
|
||||
instance_vm = compute_client.virtual_machine_scale_set_vms.get(
|
||||
resource_group, name, instance_id
|
||||
)
|
||||
except (ResourceNotFoundError, CloudError):
|
||||
return Error(
|
||||
code=ErrorCode.UNABLE_TO_FIND,
|
||||
errors=["unable to find vm instance: %s:%s" % (name, instance_id)],
|
||||
)
|
||||
|
||||
new_protection_policy = VirtualMachineScaleSetVMProtectionPolicy(
|
||||
protect_from_scale_in=protect_from_scale_in
|
||||
)
|
||||
if instance_vm.protection_policy is not None:
|
||||
new_protection_policy = instance_vm.protection_policy
|
||||
new_protection_policy.protect_from_scale_in = protect_from_scale_in
|
||||
|
||||
instance_vm.protection_policy = new_protection_policy
|
||||
|
||||
try:
|
||||
compute_client.virtual_machine_scale_set_vms.begin_update(
|
||||
resource_group, name, instance_id, instance_vm
|
||||
)
|
||||
except (ResourceNotFoundError, CloudError, HttpResponseError) as err:
|
||||
if isinstance(err, HttpResponseError):
|
||||
err_str = str(err)
|
||||
instance_not_found = (
|
||||
" is not an active Virtual Machine Scale Set VM instanceId."
|
||||
)
|
||||
if (
|
||||
instance_not_found in err_str
|
||||
and instance_vm.protection_policy.protect_from_scale_in is False
|
||||
and protect_from_scale_in
|
||||
== instance_vm.protection_policy.protect_from_scale_in
|
||||
):
|
||||
logging.info(
|
||||
"Tried to remove scale in protection on node %s but the instance no longer exists" # noqa: E501
|
||||
% instance_id
|
||||
)
|
||||
return None
|
||||
if (
|
||||
"models that may be referenced by one or more"
|
||||
+ " VMs belonging to the Virtual Machine Scale Set"
|
||||
in str(err)
|
||||
):
|
||||
logging.error(
|
||||
"unable to resize scaleset due to model error. name: %s - err:%s",
|
||||
name,
|
||||
err,
|
||||
)
|
||||
return None
|
||||
|
||||
return Error(
|
||||
code=ErrorCode.UNABLE_TO_UPDATE,
|
||||
errors=["unable to set protection policy on: %s:%s" % (vm_id, instance_id)],
|
||||
)
|
||||
|
||||
logging.info(
|
||||
"Successfully set scale in protection on node %s to %s"
|
||||
% (vm_id, protect_from_scale_in)
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
class UnableToUpdate(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def check_can_update(name: UUID) -> Any:
|
||||
vmss = get_vmss(name)
|
||||
if vmss is None:
|
||||
raise UnableToUpdate
|
||||
|
||||
if vmss.provisioning_state == "Updating":
|
||||
raise UnableToUpdate
|
||||
|
||||
return vmss
|
||||
|
||||
|
||||
@retry_on_auth_failure()
|
||||
def reimage_vmss_nodes(name: UUID, vm_ids: Set[UUID]) -> Optional[Error]:
|
||||
check_can_update(name)
|
||||
|
||||
resource_group = get_base_resource_group()
|
||||
compute_client = get_compute_client()
|
||||
|
||||
instance_ids = set()
|
||||
machine_to_id = list_instance_ids(name)
|
||||
for vm_id in vm_ids:
|
||||
if vm_id in machine_to_id:
|
||||
instance_ids.add(machine_to_id[vm_id])
|
||||
else:
|
||||
logging.info("unable to find vm_id for %s:%s", name, vm_id)
|
||||
|
||||
# Nodes must be 'upgraded' before the reimage. This call makes sure
# the instance is up to date with the VMSS model.
|
||||
# The expectation is that these requests are queued and handled subsequently.
|
||||
# The VMSS Team confirmed this expectation and testing supports it, as well.
|
||||
if instance_ids:
|
||||
logging.info("upgrading VMSS nodes - name: %s vm_ids: %s", name, vm_id)
|
||||
compute_client.virtual_machine_scale_sets.begin_update_instances(
|
||||
resource_group,
|
||||
str(name),
|
||||
VirtualMachineScaleSetVMInstanceIDs(instance_ids=list(instance_ids)),
|
||||
)
|
||||
logging.info("reimaging VMSS nodes - name: %s vm_ids: %s", name, vm_id)
|
||||
compute_client.virtual_machine_scale_sets.begin_reimage_all(
|
||||
resource_group,
|
||||
str(name),
|
||||
VirtualMachineScaleSetVMInstanceIDs(instance_ids=list(instance_ids)),
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@retry_on_auth_failure()
|
||||
def delete_vmss_nodes(name: UUID, vm_ids: Set[UUID]) -> Optional[Error]:
|
||||
check_can_update(name)
|
||||
|
||||
resource_group = get_base_resource_group()
|
||||
logging.info("deleting scaleset VM - name: %s vm_ids:%s", name, vm_ids)
|
||||
compute_client = get_compute_client()
|
||||
|
||||
instance_ids = set()
|
||||
machine_to_id = list_instance_ids(name)
|
||||
for vm_id in vm_ids:
|
||||
if vm_id in machine_to_id:
|
||||
instance_ids.add(machine_to_id[vm_id])
|
||||
else:
|
||||
logging.info("unable to find vm_id for %s:%s", name, vm_id)
|
||||
|
||||
if instance_ids:
|
||||
compute_client.virtual_machine_scale_sets.begin_delete_instances(
|
||||
resource_group,
|
||||
str(name),
|
||||
VirtualMachineScaleSetVMInstanceRequiredIDs(
|
||||
instance_ids=list(instance_ids)
|
||||
),
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@retry_on_auth_failure()
|
||||
def update_extensions(name: UUID, extensions: List[Any]) -> None:
|
||||
check_can_update(name)
|
||||
|
||||
resource_group = get_base_resource_group()
|
||||
logging.info("updating VM extensions: %s", name)
|
||||
compute_client = get_compute_client()
|
||||
try:
|
||||
compute_client.virtual_machine_scale_sets.begin_update(
|
||||
resource_group,
|
||||
str(name),
|
||||
{
|
||||
"virtual_machine_profile": {
|
||||
"extension_profile": {"extensions": extensions}
|
||||
}
|
||||
},
|
||||
)
|
||||
logging.info("VM extensions updated: %s", name)
|
||||
except HttpResponseError as err:
|
||||
if (
|
||||
"models that may be referenced by one or more"
|
||||
+ " VMs belonging to the Virtual Machine Scale Set"
|
||||
in str(err)
|
||||
):
|
||||
logging.error(
|
||||
"unable to resize scaleset due to model error. name: %s - err:%s",
|
||||
name,
|
||||
err,
|
||||
)
|
||||
|
||||
|
||||
@retry_on_auth_failure()
|
||||
def create_vmss(
|
||||
location: Region,
|
||||
name: UUID,
|
||||
vm_sku: str,
|
||||
vm_count: int,
|
||||
image: str,
|
||||
network_id: str,
|
||||
spot_instances: bool,
|
||||
ephemeral_os_disks: bool,
|
||||
extensions: List[Any],
|
||||
password: str,
|
||||
ssh_public_key: str,
|
||||
tags: Dict[str, str],
|
||||
) -> Optional[Error]:
|
||||
|
||||
vmss = get_vmss(name)
|
||||
if vmss is not None:
|
||||
return None
|
||||
|
||||
logging.info(
|
||||
"creating VM "
|
||||
"name: %s vm_sku: %s vm_count: %d "
|
||||
"image: %s subnet: %s spot_instances: %s",
|
||||
name,
|
||||
vm_sku,
|
||||
vm_count,
|
||||
image,
|
||||
network_id,
|
||||
spot_instances,
|
||||
)
|
||||
|
||||
resource_group = get_base_resource_group()
|
||||
|
||||
compute_client = get_compute_client()
|
||||
|
||||
if image.startswith("/"):
|
||||
image_ref = {"id": image}
|
||||
else:
|
||||
image_val = image.split(":", 4)
|
||||
image_ref = {
|
||||
"publisher": image_val[0],
|
||||
"offer": image_val[1],
|
||||
"sku": image_val[2],
|
||||
"version": image_val[3],
|
||||
}
|
||||
|
||||
sku = {"name": vm_sku, "tier": "Standard", "capacity": vm_count}
|
||||
|
||||
params: Dict[str, Any] = {
|
||||
"location": location,
|
||||
"do_not_run_extensions_on_overprovisioned_vms": True,
|
||||
"upgrade_policy": {"mode": "Manual"},
|
||||
"sku": sku,
|
||||
"overprovision": False,
|
||||
"identity": {
|
||||
"type": "userAssigned",
|
||||
"userAssignedIdentities": {get_scaleset_identity_resource_path(): {}},
|
||||
},
|
||||
"virtual_machine_profile": {
|
||||
"priority": "Regular",
|
||||
"storage_profile": {
|
||||
"image_reference": image_ref,
|
||||
},
|
||||
"os_profile": {
|
||||
"computer_name_prefix": "node",
|
||||
"admin_username": "onefuzz",
|
||||
},
|
||||
"network_profile": {
|
||||
"network_interface_configurations": [
|
||||
{
|
||||
"name": "onefuzz-nic",
|
||||
"primary": True,
|
||||
"ip_configurations": [
|
||||
{"name": "onefuzz-ip-config", "subnet": {"id": network_id}}
|
||||
],
|
||||
}
|
||||
]
|
||||
},
|
||||
"extension_profile": {"extensions": extensions},
|
||||
},
|
||||
"single_placement_group": False,
|
||||
}
|
||||
|
||||
image_os = get_os(location, image)
|
||||
if isinstance(image_os, Error):
|
||||
return image_os
|
||||
|
||||
if image_os == OS.windows:
|
||||
params["virtual_machine_profile"]["os_profile"]["admin_password"] = password
|
||||
|
||||
if image_os == OS.linux:
|
||||
params["virtual_machine_profile"]["os_profile"]["linux_configuration"] = {
|
||||
"disable_password_authentication": True,
|
||||
"ssh": {
|
||||
"public_keys": [
|
||||
{
|
||||
"path": "/home/onefuzz/.ssh/authorized_keys",
|
||||
"key_data": ssh_public_key,
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
if ephemeral_os_disks:
|
||||
params["virtual_machine_profile"]["storage_profile"]["os_disk"] = {
|
||||
"diffDiskSettings": {"option": "Local"},
|
||||
"caching": "ReadOnly",
|
||||
"createOption": "FromImage",
|
||||
}
|
||||
|
||||
if spot_instances:
|
||||
# Setting max price to -1 means it won't be evicted because of
|
||||
# price.
|
||||
#
|
||||
# https://docs.microsoft.com/en-us/azure/
|
||||
# virtual-machine-scale-sets/use-spot#resource-manager-templates
|
||||
params["virtual_machine_profile"].update(
|
||||
{
|
||||
"eviction_policy": "Delete",
|
||||
"priority": "Spot",
|
||||
"billing_profile": {"max_price": -1},
|
||||
}
|
||||
)
|
||||
|
||||
params["tags"] = tags.copy()
|
||||
|
||||
owner = os.environ.get("ONEFUZZ_OWNER")
|
||||
if owner:
|
||||
params["tags"]["OWNER"] = owner
|
||||
|
||||
try:
|
||||
compute_client.virtual_machine_scale_sets.begin_create_or_update(
|
||||
resource_group, name, params
|
||||
)
|
||||
except ResourceExistsError as err:
|
||||
err_str = str(err)
|
||||
if "SkuNotAvailable" in err_str or "OperationNotAllowed" in err_str:
|
||||
return Error(
|
||||
code=ErrorCode.VM_CREATE_FAILED, errors=[f"creating vmss: {err_str}"]
|
||||
)
|
||||
raise err
|
||||
except (ResourceNotFoundError, CloudError) as err:
|
||||
if "The request failed due to conflict with a concurrent request" in repr(err):
|
||||
logging.debug(
|
||||
"create VM had conflicts with concurrent request, ignoring %s", err
|
||||
)
|
||||
return None
|
||||
return Error(
|
||||
code=ErrorCode.VM_CREATE_FAILED,
|
||||
errors=["creating vmss: %s" % err],
|
||||
)
|
||||
except HttpResponseError as err:
|
||||
err_str = str(err)
|
||||
# Catch Gen2 Hypervisor / Image mismatch errors
|
||||
# See https://github.com/microsoft/lisa/pull/716 for an example
|
||||
if (
|
||||
"check that the Hypervisor Generation of the Image matches the "
|
||||
"Hypervisor Generation of the selected VM Size"
|
||||
) in err_str:
|
||||
return Error(
|
||||
code=ErrorCode.VM_CREATE_FAILED, errors=[f"creating vmss: {err_str}"]
|
||||
)
|
||||
raise err
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@cached(ttl=60)
|
||||
@retry_on_auth_failure()
|
||||
def list_available_skus(region: Region) -> List[str]:
|
||||
compute_client = get_compute_client()
|
||||
|
||||
skus: List[ResourceSku] = list(
|
||||
compute_client.resource_skus.list(filter="location eq '%s'" % region)
|
||||
)
|
||||
sku_names: List[str] = []
|
||||
for sku in skus:
|
||||
available = True
|
||||
if sku.restrictions is not None:
|
||||
for restriction in sku.restrictions:
|
||||
if restriction.type == ResourceSkuRestrictionsType.location and (
|
||||
region.upper() in [v.upper() for v in restriction.values]
|
||||
):
|
||||
available = False
|
||||
break
|
||||
|
||||
if available:
|
||||
sku_names.append(sku.name)
|
||||
return sku_names
|
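A sketch of how the scale-in protection and resize helpers above fit together when shrinking a scaleset. Which nodes to keep is a policy decision; protecting the first new_size instances here is purely illustrative.

from uuid import UUID

# module path assumed
from .vmss import (
    get_vmss_size,
    list_instance_ids,
    resize_vmss,
    update_scale_in_protection,
)


def shrink_scaleset_example(scaleset_id: UUID, new_size: int) -> None:
    current = get_vmss_size(scaleset_id)
    if current is None or current <= new_size:
        return

    # Mark the nodes that should survive the shrink; Azure only removes
    # unprotected instances when the capacity drops.
    keep = list(list_instance_ids(scaleset_id))[:new_size]
    for vm_id in keep:
        error = update_scale_in_protection(scaleset_id, vm_id, True)
        if error is not None:
            raise RuntimeError("unable to protect node: %s" % error)

    # resize_vmss raises UnableToUpdate if the scaleset is mid-update, so
    # callers retry later.
    resize_vmss(scaleset_id, new_size)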
@ -1,34 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from onefuzztypes.events import EventInstanceConfigUpdated
|
||||
from onefuzztypes.models import InstanceConfig as BASE_CONFIG
|
||||
from pydantic import Field
|
||||
|
||||
from .azure.creds import get_instance_name
|
||||
from .events import send_event
|
||||
from .orm import ORMMixin
|
||||
|
||||
|
||||
class InstanceConfig(BASE_CONFIG, ORMMixin):
|
||||
instance_name: str = Field(default_factory=get_instance_name)
|
||||
|
||||
@classmethod
|
||||
def key_fields(cls) -> Tuple[str, Optional[str]]:
|
||||
return ("instance_name", None)
|
||||
|
||||
@classmethod
|
||||
def fetch(cls) -> "InstanceConfig":
|
||||
entry = cls.get(get_instance_name())
|
||||
if entry is None:
|
||||
entry = cls(allowed_aad_tenants=[])
|
||||
entry.save()
|
||||
return entry
|
||||
|
||||
def save(self, new: bool = False, require_etag: bool = False) -> None:
|
||||
super().save(new=new, require_etag=require_etag)
|
||||
send_event(EventInstanceConfigUpdated(config=self))
|
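A small sketch of the singleton pattern above: fetch() lazily creates the row, and every save() re-emits the config event. The admin UUID is a placeholder.

from uuid import UUID

from .config import InstanceConfig


def add_admin_example(new_admin: UUID) -> None:
    # fetch() creates a default entry on first use, so this is safe even
    # before any config has been stored.
    config = InstanceConfig.fetch()
    if config.admins is None:
        config.admins = []
    if new_admin not in config.admins:
        config.admins.append(new_admin)
        # save() persists the row and sends EventInstanceConfigUpdated.
        config.save()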
@ -1,204 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
import urllib
|
||||
from typing import Callable, Optional
|
||||
from uuid import UUID
|
||||
|
||||
import azure.functions as func
|
||||
from azure.functions import HttpRequest
|
||||
from memoization import cached
|
||||
from onefuzztypes.enums import ErrorCode
|
||||
from onefuzztypes.models import Error, UserInfo
|
||||
|
||||
from .azure.creds import get_scaleset_principal_id
|
||||
from .azure.group_membership import create_group_membership_checker
|
||||
from .config import InstanceConfig
|
||||
from .request import not_ok
|
||||
from .request_access import RequestAccess
|
||||
from .user_credentials import parse_jwt_token
|
||||
from .workers.pools import Pool
|
||||
from .workers.scalesets import Scaleset
|
||||
|
||||
|
||||
@cached(ttl=60)
|
||||
def get_rules() -> Optional[RequestAccess]:
|
||||
config = InstanceConfig.fetch()
|
||||
if config.api_access_rules:
|
||||
return RequestAccess.build(config.api_access_rules)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def check_access(req: HttpRequest) -> Optional[Error]:
|
||||
rules = get_rules()
|
||||
|
||||
# Nothing to enforce if there are no rules.
|
||||
if not rules:
|
||||
return None
|
||||
|
||||
path = urllib.parse.urlparse(req.url).path
|
||||
rule = rules.get_matching_rules(req.method, path)
|
||||
|
||||
# No restriction defined on this endpoint.
|
||||
if not rule:
|
||||
return None
|
||||
|
||||
member_id = UUID(req.headers["x-ms-client-principal-id"])
|
||||
|
||||
try:
|
||||
membership_checker = create_group_membership_checker()
|
||||
allowed = membership_checker.is_member(rule.allowed_groups_ids, member_id)
|
||||
if not allowed:
|
||||
logging.error(
|
||||
"unauthorized access: %s is not authorized to access in %s",
|
||||
member_id,
|
||||
req.url,
|
||||
)
|
||||
return Error(
|
||||
code=ErrorCode.UNAUTHORIZED,
|
||||
errors=["not approved to use this endpoint"],
|
||||
)
|
||||
except Exception as e:
|
||||
return Error(
|
||||
code=ErrorCode.UNAUTHORIZED,
|
||||
errors=["unable to interact with graph", str(e)],
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@cached(ttl=60)
|
||||
def is_agent(token_data: UserInfo) -> bool:
|
||||
|
||||
if token_data.object_id:
|
||||
# backward compatibility case for scalesets deployed before the migration
|
||||
# to user assigned managed id
|
||||
scalesets = Scaleset.get_by_object_id(token_data.object_id)
|
||||
if len(scalesets) > 0:
|
||||
return True
|
||||
|
||||
# verify object_id against the user assigned managed identity
|
||||
principal_id: UUID = get_scaleset_principal_id()
|
||||
return principal_id == token_data.object_id
|
||||
|
||||
if not token_data.application_id:
|
||||
return False
|
||||
|
||||
pools = Pool.search(query={"client_id": [token_data.application_id]})
|
||||
if len(pools) > 0:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def can_modify_config_impl(config: InstanceConfig, user_info: UserInfo) -> bool:
|
||||
if config.admins is None:
|
||||
return True
|
||||
|
||||
return user_info.object_id in config.admins
|
||||
|
||||
|
||||
def can_modify_config(req: func.HttpRequest, config: InstanceConfig) -> bool:
|
||||
user_info = parse_jwt_token(req)
|
||||
if not isinstance(user_info, UserInfo):
|
||||
return False
|
||||
|
||||
return can_modify_config_impl(config, user_info)
|
||||
|
||||
|
||||
def check_require_admins_impl(
|
||||
config: InstanceConfig, user_info: UserInfo
|
||||
) -> Optional[Error]:
|
||||
if not config.require_admin_privileges:
|
||||
return None
|
||||
|
||||
if config.admins is None:
|
||||
return Error(code=ErrorCode.UNAUTHORIZED, errors=["pool modification disabled"])
|
||||
|
||||
if user_info.object_id in config.admins:
|
||||
return None
|
||||
|
||||
return Error(code=ErrorCode.UNAUTHORIZED, errors=["not authorized to manage pools"])
|
||||
|
||||
|
||||
def check_require_admins(req: func.HttpRequest) -> Optional[Error]:
|
||||
user_info = parse_jwt_token(req)
|
||||
if isinstance(user_info, Error):
|
||||
return user_info
|
||||
|
||||
# When there are no admins in the `admins` list, all users are considered
|
||||
# admins. However, `require_admin_privileges` is still useful to protect from
|
||||
# mistakes.
|
||||
#
|
||||
# To make changes while still protecting against accidental changes to
|
||||
# pools, do the following:
|
||||
#
|
||||
# 1. set `require_admin_privileges` to `False`
|
||||
# 2. make the change
|
||||
# 3. set `require_admin_privileges` to `True`
|
||||
|
||||
config = InstanceConfig.fetch()
|
||||
|
||||
return check_require_admins_impl(config, user_info)
|
||||
|
||||
|
||||
def is_user(token_data: UserInfo) -> bool:
|
||||
return not is_agent(token_data)
|
||||
|
||||
|
||||
def reject(req: func.HttpRequest, token: UserInfo) -> func.HttpResponse:
|
||||
logging.error(
|
||||
"reject token. url:%s token:%s body:%s",
|
||||
repr(req.url),
|
||||
repr(token),
|
||||
repr(req.get_body()),
|
||||
)
|
||||
return not_ok(
|
||||
Error(code=ErrorCode.UNAUTHORIZED, errors=["Unrecognized agent"]),
|
||||
status_code=401,
|
||||
context="token verification",
|
||||
)
|
||||
|
||||
|
||||
def call_if(
|
||||
req: func.HttpRequest,
|
||||
method: Callable[[func.HttpRequest], func.HttpResponse],
|
||||
*,
|
||||
allow_user: bool = False,
|
||||
allow_agent: bool = False
|
||||
) -> func.HttpResponse:
|
||||
|
||||
token = parse_jwt_token(req)
|
||||
if isinstance(token, Error):
|
||||
return not_ok(token, status_code=401, context="token verification")
|
||||
|
||||
if is_user(token):
|
||||
if not allow_user:
|
||||
return reject(req, token)
|
||||
|
||||
access = check_access(req)
|
||||
if isinstance(access, Error):
|
||||
return not_ok(access, status_code=401, context="access control")
|
||||
|
||||
if is_agent(token) and not allow_agent:
|
||||
return reject(req, token)
|
||||
|
||||
return method(req)
|
||||
|
||||
|
||||
def call_if_user(
|
||||
req: func.HttpRequest, method: Callable[[func.HttpRequest], func.HttpResponse]
|
||||
) -> func.HttpResponse:
|
||||
|
||||
return call_if(req, method, allow_user=True)
|
||||
|
||||
|
||||
def call_if_agent(
|
||||
req: func.HttpRequest, method: Callable[[func.HttpRequest], func.HttpResponse]
|
||||
) -> func.HttpResponse:
|
||||
|
||||
return call_if(req, method, allow_agent=True)
|
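The helpers above are the entry-point guards that each HTTP-triggered function routes through. A hedged sketch of how a handler might be wired up (the handler name and response body are illustrative, not part of the removed module):

    import azure.functions as func

    def get_info(req: func.HttpRequest) -> func.HttpResponse:
        return func.HttpResponse("ok")

    def main(req: func.HttpRequest) -> func.HttpResponse:
        # agents are rejected; users must additionally pass check_access()
        return call_if_user(req, get_info)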
@ -1,81 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from onefuzztypes.events import Event, EventMessage, EventType, get_event_type
|
||||
from onefuzztypes.models import UserInfo
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .azure.creds import get_instance_id, get_instance_name
|
||||
from .azure.queue import send_message
|
||||
from .azure.storage import StorageType
|
||||
from .webhooks import Webhook
|
||||
|
||||
|
||||
class SignalREvent(BaseModel):
|
||||
target: str
|
||||
arguments: List[EventMessage]
|
||||
|
||||
|
||||
def queue_signalr_event(event_message: EventMessage) -> None:
|
||||
message = SignalREvent(target="events", arguments=[event_message]).json().encode()
|
||||
send_message("signalr-events", message, StorageType.config)
|
||||
|
||||
|
||||
def log_event(event: Event, event_type: EventType) -> None:
|
||||
scrubbed_event = filter_event(event)
|
||||
logging.info(
|
||||
"sending event: %s - %s", event_type, scrubbed_event.json(exclude_none=True)
|
||||
)
|
||||
|
||||
|
||||
def filter_event(event: Event) -> BaseModel:
|
||||
clone_event = event.copy(deep=True)
|
||||
filtered_event = filter_event_recurse(clone_event)
|
||||
return filtered_event
|
||||
|
||||
|
||||
def filter_event_recurse(entry: BaseModel) -> BaseModel:
|
||||
|
||||
for field in entry.__fields__:
|
||||
field_data = getattr(entry, field)
|
||||
|
||||
if isinstance(field_data, UserInfo):
|
||||
field_data = None
|
||||
elif isinstance(field_data, list):
|
||||
for (i, value) in enumerate(field_data):
|
||||
if isinstance(value, BaseModel):
|
||||
field_data[i] = filter_event_recurse(value)
|
||||
elif isinstance(field_data, dict):
|
||||
for (key, value) in field_data.items():
|
||||
if isinstance(value, BaseModel):
|
||||
field_data[key] = filter_event_recurse(value)
|
||||
elif isinstance(field_data, BaseModel):
|
||||
field_data = filter_event_recurse(field_data)
|
||||
|
||||
setattr(entry, field, field_data)
|
||||
|
||||
return entry
|
||||
|
||||
|
||||
def send_event(event: Event) -> None:
|
||||
event_type = get_event_type(event)
|
||||
|
||||
event_message = EventMessage(
|
||||
event_type=event_type,
|
||||
event=event.copy(deep=True),
|
||||
instance_id=get_instance_id(),
|
||||
instance_name=get_instance_name(),
|
||||
)
|
||||
|
||||
# work around odd bug with Event Message creation. See PR 939
|
||||
if event_message.event != event:
|
||||
event_message.event = event.copy(deep=True)
|
||||
|
||||
log_event(event, event_type)
|
||||
queue_signalr_event(event_message)
|
||||
Webhook.send_event(event_message)
|
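filter_event above scrubs UserInfo from an event before it is logged. A hedged sketch of the effect, using an illustrative pydantic model in place of a real onefuzztypes event (EventExample and its fields are assumptions for the example only):

    from typing import Optional
    from pydantic import BaseModel
    from onefuzztypes.models import UserInfo

    class EventExample(BaseModel):
        job_id: str
        user_info: Optional[UserInfo]

    event = EventExample(job_id="demo", user_info=UserInfo(upn="user@contoso.com"))
    scrubbed = filter_event(event)
    assert scrubbed.user_info is None   # PII removed before logging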
@ -1,515 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
from typing import List, Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from onefuzztypes.enums import OS, AgentMode
|
||||
from onefuzztypes.models import (
|
||||
AgentConfig,
|
||||
AzureMonitorExtensionConfig,
|
||||
KeyvaultExtensionConfig,
|
||||
Pool,
|
||||
ReproConfig,
|
||||
Scaleset,
|
||||
)
|
||||
from onefuzztypes.primitives import Container, Extension, Region
|
||||
|
||||
from .azure.containers import (
|
||||
get_container_sas_url,
|
||||
get_file_sas_url,
|
||||
get_file_url,
|
||||
save_blob,
|
||||
)
|
||||
from .azure.creds import get_agent_instance_url, get_instance_id
|
||||
from .azure.log_analytics import get_monitor_settings
|
||||
from .azure.queue import get_queue_sas
|
||||
from .azure.storage import StorageType
|
||||
from .config import InstanceConfig
|
||||
from .reports import get_report
|
||||
|
||||
|
||||
def generic_extensions(region: Region, vm_os: OS) -> List[Extension]:
|
||||
instance_config = InstanceConfig.fetch()
|
||||
|
||||
extensions = [monitor_extension(region, vm_os)]
|
||||
|
||||
dependency = dependency_extension(region, vm_os)
|
||||
if dependency:
|
||||
extensions.append(dependency)
|
||||
|
||||
if instance_config.extensions:
|
||||
|
||||
if instance_config.extensions.keyvault:
|
||||
keyvault = keyvault_extension(
|
||||
region, instance_config.extensions.keyvault, vm_os
|
||||
)
|
||||
extensions.append(keyvault)
|
||||
|
||||
if instance_config.extensions.geneva and vm_os == OS.windows:
|
||||
geneva = geneva_extension(region)
|
||||
extensions.append(geneva)
|
||||
|
||||
if instance_config.extensions.azure_monitor and vm_os == OS.linux:
|
||||
azmon = azmon_extension(region, instance_config.extensions.azure_monitor)
|
||||
extensions.append(azmon)
|
||||
|
||||
if instance_config.extensions.azure_security and vm_os == OS.linux:
|
||||
azsec = azsec_extension(region)
|
||||
extensions.append(azsec)
|
||||
|
||||
return extensions
|
||||
|
||||
|
||||
def monitor_extension(region: Region, vm_os: OS) -> Extension:
|
||||
settings = get_monitor_settings()
|
||||
|
||||
if vm_os == OS.windows:
|
||||
return {
|
||||
"name": "OMSExtension",
|
||||
"publisher": "Microsoft.EnterpriseCloud.Monitoring",
|
||||
"type": "MicrosoftMonitoringAgent",
|
||||
"typeHandlerVersion": "1.0",
|
||||
"location": region,
|
||||
"autoUpgradeMinorVersion": True,
|
||||
"settings": {"workspaceId": settings["id"]},
|
||||
"protectedSettings": {"workspaceKey": settings["key"]},
|
||||
}
|
||||
elif vm_os == OS.linux:
|
||||
return {
|
||||
"name": "OMSExtension",
|
||||
"publisher": "Microsoft.EnterpriseCloud.Monitoring",
|
||||
"type": "OmsAgentForLinux",
|
||||
"typeHandlerVersion": "1.12",
|
||||
"location": region,
|
||||
"autoUpgradeMinorVersion": True,
|
||||
"settings": {"workspaceId": settings["id"]},
|
||||
"protectedSettings": {"workspaceKey": settings["key"]},
|
||||
}
|
||||
raise NotImplementedError("unsupported os: %s" % vm_os)
|
||||
|
||||
|
||||
def dependency_extension(region: Region, vm_os: OS) -> Optional[Extension]:
|
||||
if vm_os == OS.windows:
|
||||
extension = {
|
||||
"name": "DependencyAgentWindows",
|
||||
"publisher": "Microsoft.Azure.Monitoring.DependencyAgent",
|
||||
"type": "DependencyAgentWindows",
|
||||
"typeHandlerVersion": "9.5",
|
||||
"location": region,
|
||||
"autoUpgradeMinorVersion": True,
|
||||
}
|
||||
return extension
|
||||
else:
|
||||
# TODO: dependency agent for linux is not reliable
|
||||
# extension = {
|
||||
# "name": "DependencyAgentLinux",
|
||||
# "publisher": "Microsoft.Azure.Monitoring.DependencyAgent",
|
||||
# "type": "DependencyAgentLinux",
|
||||
# "typeHandlerVersion": "9.5",
|
||||
# "location": vm.region,
|
||||
# "autoUpgradeMinorVersion": True,
|
||||
# }
|
||||
return None
|
||||
|
||||
|
||||
def geneva_extension(region: Region) -> Extension:
|
||||
|
||||
return {
|
||||
"name": "Microsoft.Azure.Geneva.GenevaMonitoring",
|
||||
"publisher": "Microsoft.Azure.Geneva",
|
||||
"type": "GenevaMonitoring",
|
||||
"typeHandlerVersion": "2.0",
|
||||
"location": region,
|
||||
"autoUpgradeMinorVersion": True,
|
||||
"enableAutomaticUpgrade": True,
|
||||
"settings": {},
|
||||
"protectedSettings": {},
|
||||
}
|
||||
|
||||
|
||||
def azmon_extension(
|
||||
region: Region, azure_monitor: AzureMonitorExtensionConfig
|
||||
) -> Extension:
|
||||
|
||||
auth_id = azure_monitor.monitoringGCSAuthId
|
||||
config_version = azure_monitor.config_version
|
||||
moniker = azure_monitor.moniker
|
||||
namespace = azure_monitor.namespace
|
||||
environment = azure_monitor.monitoringGSEnvironment
|
||||
account = azure_monitor.monitoringGCSAccount
|
||||
auth_id_type = azure_monitor.monitoringGCSAuthIdType
|
||||
|
||||
return {
|
||||
"name": "AzureMonitorLinuxAgent",
|
||||
"publisher": "Microsoft.Azure.Monitor",
|
||||
"location": region,
|
||||
"type": "AzureMonitorLinuxAgent",
|
||||
"typeHandlerVersion": "1.0",
|
||||
"autoUpgradeMinorVersion": True,
|
||||
"settings": {"GCS_AUTO_CONFIG": True},
|
||||
"protectedsettings": {
|
||||
"configVersion": config_version,
|
||||
"moniker": moniker,
|
||||
"namespace": namespace,
|
||||
"monitoringGCSEnvironment": environment,
|
||||
"monitoringGCSAccount": account,
|
||||
"monitoringGCSRegion": region,
|
||||
"monitoringGCSAuthId": auth_id,
|
||||
"monitoringGCSAuthIdType": auth_id_type,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def azsec_extension(region: Region) -> Extension:
|
||||
|
||||
return {
|
||||
"name": "AzureSecurityLinuxAgent",
|
||||
"publisher": "Microsoft.Azure.Security.Monitoring",
|
||||
"location": region,
|
||||
"type": "AzureSecurityLinuxAgent",
|
||||
"typeHandlerVersion": "2.0",
|
||||
"autoUpgradeMinorVersion": True,
|
||||
"settings": {"enableGenevaUpload": True, "enableAutoConfig": True},
|
||||
}
|
||||
|
||||
|
||||
def keyvault_extension(
|
||||
region: Region, keyvault: KeyvaultExtensionConfig, vm_os: OS
|
||||
) -> Extension:
|
||||
|
||||
keyvault_name = keyvault.keyvault_name
|
||||
cert_name = keyvault.cert_name
|
||||
uri = keyvault_name + cert_name
|
||||
|
||||
if vm_os == OS.windows:
|
||||
return {
|
||||
"name": "KVVMExtensionForWindows",
|
||||
"location": region,
|
||||
"publisher": "Microsoft.Azure.KeyVault",
|
||||
"type": "KeyVaultForWindows",
|
||||
"typeHandlerVersion": "1.0",
|
||||
"autoUpgradeMinorVersion": True,
|
||||
"settings": {
|
||||
"secretsManagementSettings": {
|
||||
"pollingIntervalInS": "3600",
|
||||
"certificateStoreName": "MY",
|
||||
"linkOnRenewal": False,
|
||||
"certificateStoreLocation": "LocalMachine",
|
||||
"requireInitialSync": True,
|
||||
"observedCertificates": [uri],
|
||||
}
|
||||
},
|
||||
}
|
||||
elif vm_os == OS.linux:
|
||||
cert_path = keyvault.cert_path
|
||||
extension_store = keyvault.extension_store
|
||||
cert_location = cert_path + extension_store
|
||||
return {
|
||||
"name": "KVVMExtensionForLinux",
|
||||
"location": region,
|
||||
"publisher": "Microsoft.Azure.KeyVault",
|
||||
"type": "KeyVaultForLinux",
|
||||
"typeHandlerVersion": "2.0",
|
||||
"autoUpgradeMinorVersion": True,
|
||||
"settings": {
|
||||
"secretsManagementSettings": {
|
||||
"pollingIntervalInS": "3600",
|
||||
"certificateStoreLocation": cert_location,
|
||||
"observedCertificates": [uri],
|
||||
},
|
||||
},
|
||||
}
|
||||
raise NotImplementedError("unsupported os: %s" % vm_os)
|
||||
|
||||
|
||||
def build_scaleset_script(pool: Pool, scaleset: Scaleset) -> str:
|
||||
commands = []
|
||||
extension = "ps1" if pool.os == OS.windows else "sh"
|
||||
filename = f"{scaleset.scaleset_id}/scaleset-setup.{extension}"
|
||||
sep = "\r\n" if pool.os == OS.windows else "\n"
|
||||
|
||||
if pool.os == OS.windows and scaleset.auth is not None:
|
||||
ssh_key = scaleset.auth.public_key.strip()
|
||||
ssh_path = "$env:ProgramData/ssh/administrators_authorized_keys"
|
||||
commands += [f'Set-Content -Path {ssh_path} -Value "{ssh_key}"']
|
||||
|
||||
save_blob(
|
||||
Container("vm-scripts"), filename, sep.join(commands) + sep, StorageType.config
|
||||
)
|
||||
return get_file_url(Container("vm-scripts"), filename, StorageType.config)
|
||||
|
||||
|
||||
def build_pool_config(pool: Pool) -> str:
|
||||
config = AgentConfig(
|
||||
pool_name=pool.name,
|
||||
onefuzz_url=get_agent_instance_url(),
|
||||
heartbeat_queue=get_queue_sas(
|
||||
"node-heartbeat",
|
||||
StorageType.config,
|
||||
add=True,
|
||||
),
|
||||
instance_telemetry_key=os.environ.get("APPINSIGHTS_INSTRUMENTATIONKEY"),
|
||||
microsoft_telemetry_key=os.environ.get("ONEFUZZ_TELEMETRY"),
|
||||
instance_id=get_instance_id(),
|
||||
)
|
||||
|
||||
multi_tenant_domain = os.environ.get("MULTI_TENANT_DOMAIN")
|
||||
if multi_tenant_domain:
|
||||
config.multi_tenant_domain = multi_tenant_domain
|
||||
|
||||
filename = f"{pool.name}/config.json"
|
||||
|
||||
save_blob(
|
||||
Container("vm-scripts"),
|
||||
filename,
|
||||
config.json(),
|
||||
StorageType.config,
|
||||
)
|
||||
|
||||
return config_url(Container("vm-scripts"), filename, False)
|
||||
|
||||
|
||||
def update_managed_scripts() -> None:
|
||||
commands = [
|
||||
"azcopy sync '%s' instance-specific-setup"
|
||||
% (
|
||||
get_container_sas_url(
|
||||
Container("instance-specific-setup"),
|
||||
StorageType.config,
|
||||
read=True,
|
||||
list_=True,
|
||||
)
|
||||
),
|
||||
"azcopy sync '%s' tools"
|
||||
% (
|
||||
get_container_sas_url(
|
||||
Container("tools"), StorageType.config, read=True, list_=True
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
save_blob(
|
||||
Container("vm-scripts"),
|
||||
"managed.ps1",
|
||||
"\r\n".join(commands) + "\r\n",
|
||||
StorageType.config,
|
||||
)
|
||||
save_blob(
|
||||
Container("vm-scripts"),
|
||||
"managed.sh",
|
||||
"\n".join(commands) + "\n",
|
||||
StorageType.config,
|
||||
)
|
||||
|
||||
|
||||
def config_url(container: Container, filename: str, with_sas: bool) -> str:
|
||||
if with_sas:
|
||||
return get_file_sas_url(container, filename, StorageType.config, read=True)
|
||||
else:
|
||||
return get_file_url(container, filename, StorageType.config)
|
||||
|
||||
|
||||
def agent_config(
|
||||
region: Region,
|
||||
vm_os: OS,
|
||||
mode: AgentMode,
|
||||
*,
|
||||
urls: Optional[List[str]] = None,
|
||||
with_sas: bool = False,
|
||||
) -> Extension:
|
||||
update_managed_scripts()
|
||||
|
||||
if urls is None:
|
||||
urls = []
|
||||
|
||||
if vm_os == OS.windows:
|
||||
urls += [
|
||||
config_url(Container("vm-scripts"), "managed.ps1", with_sas),
|
||||
config_url(Container("tools"), "win64/azcopy.exe", with_sas),
|
||||
config_url(
|
||||
Container("tools"),
|
||||
"win64/setup.ps1",
|
||||
with_sas,
|
||||
),
|
||||
config_url(
|
||||
Container("tools"),
|
||||
"win64/onefuzz.ps1",
|
||||
with_sas,
|
||||
),
|
||||
]
|
||||
to_execute_cmd = (
|
||||
"powershell -ExecutionPolicy Unrestricted -File win64/setup.ps1 "
|
||||
"-mode %s" % (mode.name)
|
||||
)
|
||||
extension = {
|
||||
"name": "CustomScriptExtension",
|
||||
"type": "CustomScriptExtension",
|
||||
"publisher": "Microsoft.Compute",
|
||||
"location": region,
|
||||
"force_update_tag": uuid4(),
|
||||
"type_handler_version": "1.9",
|
||||
"auto_upgrade_minor_version": True,
|
||||
"settings": {
|
||||
"commandToExecute": to_execute_cmd,
|
||||
"fileUris": urls,
|
||||
},
|
||||
"protectedSettings": {
|
||||
"managedIdentity": {},
|
||||
},
|
||||
}
|
||||
return extension
|
||||
elif vm_os == OS.linux:
|
||||
urls += [
|
||||
config_url(
|
||||
Container("vm-scripts"),
|
||||
"managed.sh",
|
||||
with_sas,
|
||||
),
|
||||
config_url(
|
||||
Container("tools"),
|
||||
"linux/azcopy",
|
||||
with_sas,
|
||||
),
|
||||
config_url(
|
||||
Container("tools"),
|
||||
"linux/setup.sh",
|
||||
with_sas,
|
||||
),
|
||||
]
|
||||
to_execute_cmd = "bash setup.sh %s" % (mode.name)
|
||||
|
||||
extension = {
|
||||
"name": "CustomScript",
|
||||
"publisher": "Microsoft.Azure.Extensions",
|
||||
"type": "CustomScript",
|
||||
"typeHandlerVersion": "2.1",
|
||||
"location": region,
|
||||
"force_update_tag": uuid4(),
|
||||
"autoUpgradeMinorVersion": True,
|
||||
"settings": {
|
||||
"commandToExecute": to_execute_cmd,
|
||||
"fileUris": urls,
|
||||
},
|
||||
"protectedSettings": {
|
||||
"managedIdentity": {},
|
||||
},
|
||||
}
|
||||
return extension
|
||||
|
||||
raise NotImplementedError("unsupported OS: %s" % vm_os)
|
||||
|
||||
|
||||
def fuzz_extensions(pool: Pool, scaleset: Scaleset) -> List[Extension]:
|
||||
urls = [build_pool_config(pool), build_scaleset_script(pool, scaleset)]
|
||||
fuzz_extension = agent_config(scaleset.region, pool.os, AgentMode.fuzz, urls=urls)
|
||||
extensions = generic_extensions(scaleset.region, pool.os)
|
||||
extensions += [fuzz_extension]
|
||||
return extensions
|
||||
|
||||
|
||||
def repro_extensions(
|
||||
region: Region,
|
||||
repro_os: OS,
|
||||
repro_id: UUID,
|
||||
repro_config: ReproConfig,
|
||||
setup_container: Optional[Container],
|
||||
) -> List[Extension]:
|
||||
# TODO - what about contents of repro.ps1 / repro.sh?
|
||||
report = get_report(repro_config.container, repro_config.path)
|
||||
if report is None:
|
||||
raise Exception("invalid report: %s" % repro_config)
|
||||
|
||||
if report.input_blob is None:
|
||||
raise Exception("unable to perform reproduction without an input blob")
|
||||
|
||||
commands = []
|
||||
if setup_container:
|
||||
commands += [
|
||||
"azcopy sync '%s' ./setup"
|
||||
% (
|
||||
get_container_sas_url(
|
||||
setup_container, StorageType.corpus, read=True, list_=True
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
urls = [
|
||||
get_file_sas_url(
|
||||
repro_config.container, repro_config.path, StorageType.corpus, read=True
|
||||
),
|
||||
get_file_sas_url(
|
||||
report.input_blob.container,
|
||||
report.input_blob.name,
|
||||
StorageType.corpus,
|
||||
read=True,
|
||||
),
|
||||
]
|
||||
|
||||
repro_files = []
|
||||
if repro_os == OS.windows:
|
||||
repro_files = ["%s/repro.ps1" % repro_id]
|
||||
task_script = "\r\n".join(commands)
|
||||
script_name = "task-setup.ps1"
|
||||
else:
|
||||
repro_files = ["%s/repro.sh" % repro_id, "%s/repro-stdout.sh" % repro_id]
|
||||
commands += ["chmod -R +x setup"]
|
||||
task_script = "\n".join(commands)
|
||||
script_name = "task-setup.sh"
|
||||
|
||||
save_blob(
|
||||
Container("task-configs"),
|
||||
"%s/%s" % (repro_id, script_name),
|
||||
task_script,
|
||||
StorageType.config,
|
||||
)
|
||||
|
||||
for repro_file in repro_files:
|
||||
urls += [
|
||||
get_file_sas_url(
|
||||
Container("repro-scripts"),
|
||||
repro_file,
|
||||
StorageType.config,
|
||||
read=True,
|
||||
),
|
||||
get_file_sas_url(
|
||||
Container("task-configs"),
|
||||
"%s/%s" % (repro_id, script_name),
|
||||
StorageType.config,
|
||||
read=True,
|
||||
),
|
||||
]
|
||||
|
||||
base_extension = agent_config(
|
||||
region, repro_os, AgentMode.repro, urls=urls, with_sas=True
|
||||
)
|
||||
extensions = generic_extensions(region, repro_os)
|
||||
extensions += [base_extension]
|
||||
return extensions
|
||||
|
||||
|
||||
def proxy_manager_extensions(region: Region, proxy_id: UUID) -> List[Extension]:
|
||||
urls = [
|
||||
get_file_sas_url(
|
||||
Container("proxy-configs"),
|
||||
"%s/%s/config.json" % (region, proxy_id),
|
||||
StorageType.config,
|
||||
read=True,
|
||||
),
|
||||
get_file_sas_url(
|
||||
Container("tools"),
|
||||
"linux/onefuzz-proxy-manager",
|
||||
StorageType.config,
|
||||
read=True,
|
||||
),
|
||||
]
|
||||
|
||||
base_extension = agent_config(
|
||||
region, OS.linux, AgentMode.proxy, urls=urls, with_sas=True
|
||||
)
|
||||
extensions = generic_extensions(region, OS.linux)
|
||||
extensions += [base_extension]
|
||||
return extensions
|
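fuzz_extensions above combines the generic monitoring extensions with the per-scaleset custom-script extension. A hedged sketch of how the resulting list might be attached to a scaleset update before an ARM call (the dictionary layout is illustrative, not the exact Azure SDK model shape):

    # `pool` and `scaleset` are assumed to be existing Pool / Scaleset records.
    extensions = fuzz_extensions(pool, scaleset)
    vmss_update = {
        "virtual_machine_profile": {
            "extension_profile": {"extensions": extensions},
        },
    }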
@ -1,14 +0,0 @@
|
||||
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from .afl import afl_linux, afl_windows
from .libfuzzer import libfuzzer_linux, libfuzzer_windows

TEMPLATES = {
    "afl_windows": afl_windows,
    "afl_linux": afl_linux,
    "libfuzzer_linux": libfuzzer_linux,
    "libfuzzer_windows": libfuzzer_windows,
}
|
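TEMPLATES above is the registry of built-in job templates keyed by name. A hedged lookup sketch (the name is one of the keys registered above):

    # Pick a built-in template by name; unknown names simply return None.
    template = TEMPLATES.get("libfuzzer_linux")
    if template is None:
        raise ValueError("no such built-in template")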
@ -1,271 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from onefuzztypes.enums import (
|
||||
OS,
|
||||
ContainerType,
|
||||
TaskType,
|
||||
UserFieldOperation,
|
||||
UserFieldType,
|
||||
)
|
||||
from onefuzztypes.job_templates import JobTemplate, UserField, UserFieldLocation
|
||||
from onefuzztypes.models import (
|
||||
JobConfig,
|
||||
TaskConfig,
|
||||
TaskContainers,
|
||||
TaskDetails,
|
||||
TaskPool,
|
||||
)
|
||||
from onefuzztypes.primitives import Container, PoolName
|
||||
|
||||
from .common import (
|
||||
DURATION_HELP,
|
||||
POOL_HELP,
|
||||
REBOOT_HELP,
|
||||
RETRY_COUNT_HELP,
|
||||
TAGS_HELP,
|
||||
TARGET_EXE_HELP,
|
||||
TARGET_OPTIONS_HELP,
|
||||
VM_COUNT_HELP,
|
||||
)
|
||||
|
||||
afl_linux = JobTemplate(
|
||||
os=OS.linux,
|
||||
job=JobConfig(project="", name=Container(""), build="", duration=1),
|
||||
tasks=[
|
||||
TaskConfig(
|
||||
job_id=(UUID(int=0)),
|
||||
task=TaskDetails(
|
||||
type=TaskType.generic_supervisor,
|
||||
duration=1,
|
||||
target_exe="fuzz.exe",
|
||||
target_env={},
|
||||
target_options=[],
|
||||
supervisor_exe="",
|
||||
supervisor_options=[],
|
||||
supervisor_input_marker="@@",
|
||||
),
|
||||
pool=TaskPool(count=1, pool_name=PoolName("")),
|
||||
containers=[
|
||||
TaskContainers(
|
||||
name=Container("afl-container-name"), type=ContainerType.tools
|
||||
),
|
||||
TaskContainers(name=Container(""), type=ContainerType.setup),
|
||||
TaskContainers(name=Container(""), type=ContainerType.crashes),
|
||||
TaskContainers(name=Container(""), type=ContainerType.inputs),
|
||||
],
|
||||
tags={},
|
||||
),
|
||||
TaskConfig(
|
||||
job_id=UUID(int=0),
|
||||
prereq_tasks=[UUID(int=0)],
|
||||
task=TaskDetails(
|
||||
type=TaskType.generic_crash_report,
|
||||
duration=1,
|
||||
target_exe="fuzz.exe",
|
||||
target_env={},
|
||||
target_options=[],
|
||||
check_debugger=True,
|
||||
),
|
||||
pool=TaskPool(count=1, pool_name=PoolName("")),
|
||||
containers=[
|
||||
TaskContainers(name=Container(""), type=ContainerType.setup),
|
||||
TaskContainers(name=Container(""), type=ContainerType.crashes),
|
||||
TaskContainers(name=Container(""), type=ContainerType.no_repro),
|
||||
TaskContainers(name=Container(""), type=ContainerType.reports),
|
||||
TaskContainers(name=Container(""), type=ContainerType.unique_reports),
|
||||
],
|
||||
tags={},
|
||||
),
|
||||
],
|
||||
notifications=[],
|
||||
user_fields=[
|
||||
UserField(
|
||||
name="pool_name",
|
||||
help=POOL_HELP,
|
||||
type=UserFieldType.Str,
|
||||
required=True,
|
||||
locations=[
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/0/pool/pool_name",
|
||||
),
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/1/pool/pool_name",
|
||||
),
|
||||
],
|
||||
),
|
||||
UserField(
|
||||
name="duration",
|
||||
help=DURATION_HELP,
|
||||
type=UserFieldType.Int,
|
||||
default=24,
|
||||
locations=[
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/0/task/duration",
|
||||
),
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/1/task/duration",
|
||||
),
|
||||
UserFieldLocation(op=UserFieldOperation.replace, path="/job/duration"),
|
||||
],
|
||||
),
|
||||
UserField(
|
||||
name="target_exe",
|
||||
help=TARGET_EXE_HELP,
|
||||
type=UserFieldType.Str,
|
||||
default="fuzz.exe",
|
||||
locations=[
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/0/task/target_exe",
|
||||
),
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/1/task/target_exe",
|
||||
),
|
||||
],
|
||||
),
|
||||
UserField(
|
||||
name="target_options",
|
||||
help=TARGET_OPTIONS_HELP,
|
||||
type=UserFieldType.ListStr,
|
||||
locations=[
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/0/task/target_options",
|
||||
),
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/1/task/target_options",
|
||||
),
|
||||
],
|
||||
),
|
||||
UserField(
|
||||
name="supervisor_exe",
|
||||
help="Path to the AFL executable",
|
||||
type=UserFieldType.Str,
|
||||
default="{tools_dir}/afl-fuzz",
|
||||
locations=[
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/0/task/supervisor_exe",
|
||||
),
|
||||
],
|
||||
),
|
||||
UserField(
|
||||
name="supervisor_options",
|
||||
help="AFL command line options",
|
||||
type=UserFieldType.ListStr,
|
||||
default=[
|
||||
"-d",
|
||||
"-i",
|
||||
"{input_corpus}",
|
||||
"-o",
|
||||
"{runtime_dir}",
|
||||
"--",
|
||||
"{target_exe}",
|
||||
"{target_options}",
|
||||
],
|
||||
locations=[
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/0/task/supervisor_options",
|
||||
),
|
||||
],
|
||||
),
|
||||
UserField(
|
||||
name="supervisor_env",
|
||||
help="Enviornment variables for AFL",
|
||||
type=UserFieldType.DictStr,
|
||||
locations=[
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/0/task/supervisor_env",
|
||||
),
|
||||
],
|
||||
),
|
||||
UserField(
|
||||
name="vm_count",
|
||||
help=VM_COUNT_HELP,
|
||||
type=UserFieldType.Int,
|
||||
default=2,
|
||||
locations=[
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/0/pool/count",
|
||||
),
|
||||
],
|
||||
),
|
||||
UserField(
|
||||
name="check_retry_count",
|
||||
help=RETRY_COUNT_HELP,
|
||||
type=UserFieldType.Int,
|
||||
locations=[
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/1/task/check_retry_count",
|
||||
),
|
||||
],
|
||||
),
|
||||
UserField(
|
||||
name="afl_container",
|
||||
help=(
|
||||
"Name of the AFL storage container (use "
|
||||
"this to specify alternate builds of AFL)"
|
||||
),
|
||||
type=UserFieldType.Str,
|
||||
default="afl-linux",
|
||||
locations=[
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/0/containers/0/name",
|
||||
),
|
||||
],
|
||||
),
|
||||
UserField(
|
||||
name="reboot_after_setup",
|
||||
help=REBOOT_HELP,
|
||||
type=UserFieldType.Bool,
|
||||
default=False,
|
||||
locations=[
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/0/task/reboot_after_setup",
|
||||
),
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/1/task/reboot_after_setup",
|
||||
),
|
||||
],
|
||||
),
|
||||
UserField(
|
||||
name="tags",
|
||||
help=TAGS_HELP,
|
||||
type=UserFieldType.DictStr,
|
||||
locations=[
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.add,
|
||||
path="/tasks/0/tags",
|
||||
),
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.add,
|
||||
path="/tasks/1/tags",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
afl_windows = afl_linux.copy(deep=True)
|
||||
afl_windows.os = OS.windows
|
||||
for user_field in afl_windows.user_fields:
|
||||
if user_field.name == "afl_container":
|
||||
user_field.default = "afl-windows"
|
@ -1,14 +0,0 @@
|
||||
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.


POOL_HELP = "Execute the task on the specified pool"
DURATION_HELP = "Number of hours to execute the task"
TARGET_EXE_HELP = "Path to the target executable"
TARGET_OPTIONS_HELP = "Command line options for the target"
VM_COUNT_HELP = "Number of VMs to use for fuzzing"
RETRY_COUNT_HELP = "Number of times to retry a crash to verify reproducibility"
REBOOT_HELP = "After executing the setup script, reboot the VM"
TAGS_HELP = "User provided metadata for the tasks"
|
@ -1,348 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from onefuzztypes.enums import (
|
||||
OS,
|
||||
ContainerType,
|
||||
TaskType,
|
||||
UserFieldOperation,
|
||||
UserFieldType,
|
||||
)
|
||||
from onefuzztypes.job_templates import JobTemplate, UserField, UserFieldLocation
|
||||
from onefuzztypes.models import (
|
||||
JobConfig,
|
||||
TaskConfig,
|
||||
TaskContainers,
|
||||
TaskDetails,
|
||||
TaskPool,
|
||||
)
|
||||
from onefuzztypes.primitives import Container, PoolName
|
||||
|
||||
from .common import (
|
||||
DURATION_HELP,
|
||||
POOL_HELP,
|
||||
REBOOT_HELP,
|
||||
RETRY_COUNT_HELP,
|
||||
TAGS_HELP,
|
||||
TARGET_EXE_HELP,
|
||||
TARGET_OPTIONS_HELP,
|
||||
VM_COUNT_HELP,
|
||||
)
|
||||
|
||||
libfuzzer_linux = JobTemplate(
|
||||
os=OS.linux,
|
||||
job=JobConfig(project="", name=Container(""), build="", duration=1),
|
||||
tasks=[
|
||||
TaskConfig(
|
||||
job_id=UUID(int=0),
|
||||
task=TaskDetails(
|
||||
type=TaskType.libfuzzer_fuzz,
|
||||
duration=1,
|
||||
target_exe="fuzz.exe",
|
||||
target_env={},
|
||||
target_options=[],
|
||||
),
|
||||
pool=TaskPool(count=1, pool_name=PoolName("")),
|
||||
containers=[
|
||||
TaskContainers(name=Container(""), type=ContainerType.setup),
|
||||
TaskContainers(name=Container(""), type=ContainerType.crashes),
|
||||
TaskContainers(name=Container(""), type=ContainerType.inputs),
|
||||
],
|
||||
tags={},
|
||||
),
|
||||
TaskConfig(
|
||||
job_id=UUID(int=0),
|
||||
prereq_tasks=[UUID(int=0)],
|
||||
task=TaskDetails(
|
||||
type=TaskType.libfuzzer_crash_report,
|
||||
duration=1,
|
||||
target_exe="fuzz.exe",
|
||||
target_env={},
|
||||
target_options=[],
|
||||
),
|
||||
pool=TaskPool(count=1, pool_name=PoolName("")),
|
||||
containers=[
|
||||
TaskContainers(name=Container(""), type=ContainerType.setup),
|
||||
TaskContainers(name=Container(""), type=ContainerType.crashes),
|
||||
TaskContainers(name=Container(""), type=ContainerType.no_repro),
|
||||
TaskContainers(name=Container(""), type=ContainerType.reports),
|
||||
TaskContainers(name=Container(""), type=ContainerType.unique_reports),
|
||||
],
|
||||
tags={},
|
||||
),
|
||||
TaskConfig(
|
||||
job_id=UUID(int=0),
|
||||
prereq_tasks=[UUID(int=0)],
|
||||
task=TaskDetails(
|
||||
type=TaskType.coverage,
|
||||
duration=1,
|
||||
target_exe="fuzz.exe",
|
||||
target_env={},
|
||||
target_options=[],
|
||||
),
|
||||
pool=TaskPool(count=1, pool_name=PoolName("")),
|
||||
containers=[
|
||||
TaskContainers(name=Container(""), type=ContainerType.setup),
|
||||
TaskContainers(name=Container(""), type=ContainerType.readonly_inputs),
|
||||
TaskContainers(name=Container(""), type=ContainerType.coverage),
|
||||
],
|
||||
tags={},
|
||||
),
|
||||
],
|
||||
notifications=[],
|
||||
user_fields=[
|
||||
UserField(
|
||||
name="pool_name",
|
||||
help=POOL_HELP,
|
||||
type=UserFieldType.Str,
|
||||
required=True,
|
||||
locations=[
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/0/pool/pool_name",
|
||||
),
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/1/pool/pool_name",
|
||||
),
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/2/pool/pool_name",
|
||||
),
|
||||
],
|
||||
),
|
||||
UserField(
|
||||
name="target_exe",
|
||||
help=TARGET_EXE_HELP,
|
||||
type=UserFieldType.Str,
|
||||
default="fuzz.exe",
|
||||
locations=[
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/0/task/target_exe",
|
||||
),
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/1/task/target_exe",
|
||||
),
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/2/task/target_exe",
|
||||
),
|
||||
],
|
||||
),
|
||||
UserField(
|
||||
name="duration",
|
||||
help=DURATION_HELP,
|
||||
type=UserFieldType.Int,
|
||||
default=24,
|
||||
locations=[
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/0/task/duration",
|
||||
),
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/1/task/duration",
|
||||
),
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/2/task/duration",
|
||||
),
|
||||
UserFieldLocation(op=UserFieldOperation.replace, path="/job/duration"),
|
||||
],
|
||||
),
|
||||
UserField(
|
||||
name="target_workers",
|
||||
help="Number of instances of the libfuzzer target on each VM",
|
||||
type=UserFieldType.Int,
|
||||
locations=[
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/0/task/target_workers",
|
||||
),
|
||||
],
|
||||
),
|
||||
UserField(
|
||||
name="vm_count",
|
||||
help=VM_COUNT_HELP,
|
||||
type=UserFieldType.Int,
|
||||
default=2,
|
||||
locations=[
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/0/pool/count",
|
||||
),
|
||||
],
|
||||
),
|
||||
UserField(
|
||||
name="target_options",
|
||||
help=TARGET_OPTIONS_HELP,
|
||||
type=UserFieldType.ListStr,
|
||||
locations=[
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/0/task/target_options",
|
||||
),
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/1/task/target_options",
|
||||
),
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/2/task/target_options",
|
||||
),
|
||||
],
|
||||
),
|
||||
UserField(
|
||||
name="target_env",
|
||||
help="Environment variables for the target",
|
||||
type=UserFieldType.DictStr,
|
||||
locations=[
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/0/task/target_env",
|
||||
),
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/1/task/target_env",
|
||||
),
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/2/task/target_env",
|
||||
),
|
||||
],
|
||||
),
|
||||
UserField(
|
||||
name="check_fuzzer_help",
|
||||
help="Verify fuzzer by checking if it supports -help=1",
|
||||
type=UserFieldType.Bool,
|
||||
default=True,
|
||||
locations=[
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.add,
|
||||
path="/tasks/0/task/check_fuzzer_help",
|
||||
),
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.add,
|
||||
path="/tasks/1/task/check_fuzzer_help",
|
||||
),
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.add,
|
||||
path="/tasks/2/task/check_fuzzer_help",
|
||||
),
|
||||
],
|
||||
),
|
||||
UserField(
|
||||
name="colocate",
|
||||
help="Run all of the tasks on the same node",
|
||||
type=UserFieldType.Bool,
|
||||
default=True,
|
||||
locations=[
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.add,
|
||||
path="/tasks/0/colocate",
|
||||
),
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.add,
|
||||
path="/tasks/1/colocate",
|
||||
),
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.add,
|
||||
path="/tasks/2/colocate",
|
||||
),
|
||||
],
|
||||
),
|
||||
UserField(
|
||||
name="expect_crash_on_failure",
|
||||
help="Require crashes upon non-zero exits from libfuzzer",
|
||||
type=UserFieldType.Bool,
|
||||
default=False,
|
||||
locations=[
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.add,
|
||||
path="/tasks/0/task/expect_crash_on_failure",
|
||||
),
|
||||
],
|
||||
),
|
||||
UserField(
|
||||
name="reboot_after_setup",
|
||||
help=REBOOT_HELP,
|
||||
type=UserFieldType.Bool,
|
||||
default=False,
|
||||
locations=[
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/0/task/reboot_after_setup",
|
||||
),
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/1/task/reboot_after_setup",
|
||||
),
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/2/task/reboot_after_setup",
|
||||
),
|
||||
],
|
||||
),
|
||||
UserField(
|
||||
name="check_retry_count",
|
||||
help=RETRY_COUNT_HELP,
|
||||
type=UserFieldType.Int,
|
||||
locations=[
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/1/task/check_retry_count",
|
||||
),
|
||||
],
|
||||
),
|
||||
UserField(
|
||||
name="target_timeout",
|
||||
help="Number of seconds to timeout during reproduction",
|
||||
type=UserFieldType.Int,
|
||||
locations=[
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/1/task/target_timeout",
|
||||
),
|
||||
],
|
||||
),
|
||||
UserField(
|
||||
name="minimized_stack_depth",
|
||||
help="Number of frames to include in the minimized stack",
|
||||
type=UserFieldType.Int,
|
||||
locations=[
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.replace,
|
||||
path="/tasks/1/task/minimized_stack_depth",
|
||||
),
|
||||
],
|
||||
),
|
||||
UserField(
|
||||
name="tags",
|
||||
help=TAGS_HELP,
|
||||
type=UserFieldType.DictStr,
|
||||
locations=[
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.add,
|
||||
path="/tasks/0/tags",
|
||||
),
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.add,
|
||||
path="/tasks/1/tags",
|
||||
),
|
||||
UserFieldLocation(
|
||||
op=UserFieldOperation.add,
|
||||
path="/tasks/2/tags",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
libfuzzer_windows = libfuzzer_linux.copy(deep=True)
|
||||
libfuzzer_windows.os = OS.windows
|
@ -1,131 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import json
|
||||
from typing import Dict, List
|
||||
|
||||
from jsonpatch import apply_patch
|
||||
from memoization import cached
|
||||
from onefuzztypes.enums import ContainerType, ErrorCode, UserFieldType
|
||||
from onefuzztypes.job_templates import (
|
||||
TEMPLATE_BASE_FIELDS,
|
||||
JobTemplate,
|
||||
JobTemplateConfig,
|
||||
JobTemplateField,
|
||||
JobTemplateRequest,
|
||||
TemplateUserData,
|
||||
UserField,
|
||||
)
|
||||
from onefuzztypes.models import Error, Result
|
||||
|
||||
|
||||
def template_container_types(template: JobTemplate) -> List[ContainerType]:
|
||||
return list(set(c.type for t in template.tasks for c in t.containers if not c.name))
|
||||
|
||||
|
||||
@cached
|
||||
def build_input_config(name: str, template: JobTemplate) -> JobTemplateConfig:
|
||||
user_fields = [
|
||||
JobTemplateField(
|
||||
name=x.name,
|
||||
type=x.type,
|
||||
required=x.required,
|
||||
default=x.default,
|
||||
help=x.help,
|
||||
)
|
||||
for x in TEMPLATE_BASE_FIELDS + template.user_fields
|
||||
]
|
||||
containers = template_container_types(template)
|
||||
|
||||
return JobTemplateConfig(
|
||||
os=template.os,
|
||||
name=name,
|
||||
user_fields=user_fields,
|
||||
containers=containers,
|
||||
)
|
||||
|
||||
|
||||
def build_patches(
|
||||
data: TemplateUserData, field: UserField
|
||||
) -> List[Dict[str, TemplateUserData]]:
|
||||
patches = []
|
||||
|
||||
if field.type == UserFieldType.Bool and not isinstance(data, bool):
|
||||
raise Exception("invalid bool field")
|
||||
if field.type == UserFieldType.Int and not isinstance(data, int):
|
||||
raise Exception("invalid int field")
|
||||
if field.type == UserFieldType.Str and not isinstance(data, str):
|
||||
raise Exception("invalid str field")
|
||||
if field.type == UserFieldType.DictStr and not isinstance(data, dict):
|
||||
raise Exception("invalid DictStr field")
|
||||
if field.type == UserFieldType.ListStr and not isinstance(data, list):
|
||||
raise Exception("invalid ListStr field")
|
||||
|
||||
for location in field.locations:
|
||||
patches.append(
|
||||
{
|
||||
"op": location.op.name,
|
||||
"path": location.path,
|
||||
"value": data,
|
||||
}
|
||||
)
|
||||
|
||||
return patches
|
||||
|
||||
|
||||
def _fail(why: str) -> Error:
|
||||
return Error(code=ErrorCode.INVALID_REQUEST, errors=[why])
|
||||
|
||||
|
||||
def render(request: JobTemplateRequest, template: JobTemplate) -> Result[JobTemplate]:
|
||||
patches = []
|
||||
seen = set()
|
||||
|
||||
for name in request.user_fields:
|
||||
for field in TEMPLATE_BASE_FIELDS + template.user_fields:
|
||||
if field.name == name:
|
||||
if name in seen:
|
||||
return _fail(f"duplicate specification: {name}")
|
||||
|
||||
seen.add(name)
|
||||
|
||||
if name not in seen:
|
||||
return _fail(f"extra field: {name}")
|
||||
|
||||
for field in TEMPLATE_BASE_FIELDS + template.user_fields:
|
||||
if field.name not in request.user_fields:
|
||||
if field.required:
|
||||
return _fail(f"missing required field: {field.name}")
|
||||
else:
|
||||
# optional fields can be missing
|
||||
continue
|
||||
|
||||
patches += build_patches(request.user_fields[field.name], field)
|
||||
|
||||
raw = json.loads(template.json())
|
||||
updated = apply_patch(raw, patches)
|
||||
rendered = JobTemplate.parse_obj(updated)
|
||||
|
||||
used_containers = []
|
||||
for task in rendered.tasks:
|
||||
for task_container in task.containers:
|
||||
if task_container.name:
|
||||
# containers that already have a name are left as-is
|
||||
continue
|
||||
|
||||
for entry in request.containers:
|
||||
if entry.type != task_container.type:
|
||||
continue
|
||||
task_container.name = entry.name
|
||||
used_containers.append(entry)
|
||||
|
||||
if not task_container.name:
|
||||
return _fail(f"missing container definition {task_container.type}")
|
||||
|
||||
for entry in request.containers:
|
||||
if entry not in used_containers:
|
||||
return _fail(f"unused container in request: {entry}")
|
||||
|
||||
return rendered
|
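render above validates user fields and then applies them as JSON patches over the serialized template. A self-contained, hedged sketch of that patch step in isolation (paths and values are illustrative):

    import json
    from jsonpatch import apply_patch

    raw = json.loads('{"tasks": [{"pool": {"count": 1, "pool_name": ""}}]}')
    patches = [
        {"op": "replace", "path": "/tasks/0/pool/pool_name", "value": "linux-pool"},
        {"op": "replace", "path": "/tasks/0/pool/count", "value": 4},
    ]
    print(apply_patch(raw, patches))
    # {'tasks': [{'pool': {'count': 4, 'pool_name': 'linux-pool'}}]}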
@ -1,118 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from onefuzztypes.enums import ErrorCode
|
||||
from onefuzztypes.job_templates import JobTemplateConfig
|
||||
from onefuzztypes.job_templates import JobTemplateIndex as BASE_INDEX
|
||||
from onefuzztypes.job_templates import JobTemplateRequest
|
||||
from onefuzztypes.models import Error, Result, UserInfo
|
||||
|
||||
from ..jobs import Job
|
||||
from ..notifications.main import Notification
|
||||
from ..orm import ORMMixin
|
||||
from ..tasks.config import TaskConfigError, check_config
|
||||
from ..tasks.main import Task
|
||||
from .defaults import TEMPLATES
|
||||
from .render import build_input_config, render
|
||||
|
||||
|
||||
class JobTemplateIndex(BASE_INDEX, ORMMixin):
|
||||
@classmethod
|
||||
def key_fields(cls) -> Tuple[str, Optional[str]]:
|
||||
return ("name", None)
|
||||
|
||||
@classmethod
|
||||
def get_base_entry(cls, name: str) -> Optional[BASE_INDEX]:
|
||||
result = cls.get(name)
|
||||
if result is not None:
|
||||
return BASE_INDEX(name=name, template=result.template)
|
||||
|
||||
template = TEMPLATES.get(name)
|
||||
if template is None:
|
||||
return None
|
||||
|
||||
return BASE_INDEX(name=name, template=template)
|
||||
|
||||
@classmethod
|
||||
def get_index(cls) -> List[BASE_INDEX]:
|
||||
entries = [BASE_INDEX(name=x.name, template=x.template) for x in cls.search()]
|
||||
|
||||
# if the local install has replaced the built-in templates, skip over them
|
||||
for name, template in TEMPLATES.items():
|
||||
if any(x.name == name for x in entries):
|
||||
continue
|
||||
entries.append(BASE_INDEX(name=name, template=template))
|
||||
|
||||
return entries
|
||||
|
||||
@classmethod
|
||||
def get_configs(cls) -> List[JobTemplateConfig]:
|
||||
configs = [build_input_config(x.name, x.template) for x in cls.get_index()]
|
||||
|
||||
return configs
|
||||
|
||||
@classmethod
|
||||
def execute(cls, request: JobTemplateRequest, user_info: UserInfo) -> Result[Job]:
|
||||
index = cls.get(request.name)
|
||||
if index is None:
|
||||
if request.name not in TEMPLATES:
|
||||
return Error(
|
||||
code=ErrorCode.INVALID_REQUEST,
|
||||
errors=["no such template: %s" % request.name],
|
||||
)
|
||||
base_template = TEMPLATES[request.name]
|
||||
else:
|
||||
base_template = index.template
|
||||
|
||||
template = render(request, base_template)
|
||||
if isinstance(template, Error):
|
||||
return template
|
||||
|
||||
try:
|
||||
for task_config in template.tasks:
|
||||
check_config(task_config)
|
||||
if task_config.pool is None:
|
||||
return Error(
|
||||
code=ErrorCode.INVALID_REQUEST, errors=["pool not defined"]
|
||||
)
|
||||
|
||||
except TaskConfigError as err:
|
||||
return Error(code=ErrorCode.INVALID_REQUEST, errors=[str(err)])
|
||||
|
||||
for notification_config in template.notifications:
|
||||
for task_container in request.containers:
|
||||
if task_container.type == notification_config.container_type:
|
||||
notification = Notification.create(
|
||||
task_container.name,
|
||||
notification_config.notification.config,
|
||||
True,
|
||||
)
|
||||
if isinstance(notification, Error):
|
||||
return notification
|
||||
|
||||
job = Job(config=template.job)
|
||||
job.save()
|
||||
|
||||
tasks: List[Task] = []
|
||||
for task_config in template.tasks:
|
||||
task_config.job_id = job.job_id
|
||||
if task_config.prereq_tasks:
|
||||
# pydantic verifies prereq_tasks in u128 form are index refs to
|
||||
# previously generated tasks
|
||||
task_config.prereq_tasks = [
|
||||
tasks[x.int].task_id for x in task_config.prereq_tasks
|
||||
]
|
||||
|
||||
task = Task.create(
|
||||
config=task_config, job_id=job.job_id, user_info=user_info
|
||||
)
|
||||
if isinstance(task, Error):
|
||||
return task
|
||||
|
||||
tasks.append(task)
|
||||
|
||||
return job
|
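execute above rewrites prereq_tasks from placeholder UUIDs into real task IDs once the earlier tasks exist. A hedged illustration of the UUID(int=N) index convention it relies on:

    from uuid import UUID

    placeholder = UUID(int=0)   # "depends on the first task in this template"
    print(placeholder.int)      # 0 -> index into the generated `tasks` list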
@ -1,144 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from onefuzztypes.enums import ErrorCode, JobState, TaskState
|
||||
from onefuzztypes.events import EventJobCreated, EventJobStopped, JobTaskStopped
|
||||
from onefuzztypes.models import Error
|
||||
from onefuzztypes.models import Job as BASE_JOB
|
||||
|
||||
from .events import send_event
|
||||
from .orm import MappingIntStrAny, ORMMixin, QueryFilter
|
||||
from .tasks.main import Task
|
||||
|
||||
JOB_LOG_PREFIX = "jobs: "
|
||||
JOB_NEVER_STARTED_DURATION: timedelta = timedelta(days=30)
|
||||
|
||||
|
||||
class Job(BASE_JOB, ORMMixin):
|
||||
@classmethod
|
||||
def key_fields(cls) -> Tuple[str, Optional[str]]:
|
||||
return ("job_id", None)
|
||||
|
||||
@classmethod
|
||||
def search_states(cls, *, states: Optional[List[JobState]] = None) -> List["Job"]:
|
||||
query: QueryFilter = {}
|
||||
if states:
|
||||
query["state"] = states
|
||||
return cls.search(query=query)
|
||||
|
||||
@classmethod
|
||||
def search_expired(cls) -> List["Job"]:
|
||||
time_filter = "end_time lt datetime'%s'" % datetime.utcnow().isoformat()
|
||||
|
||||
return cls.search(
|
||||
query={"state": JobState.available()}, raw_unchecked_filter=time_filter
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def stop_never_started_jobs(cls) -> None:
|
||||
# Note, the "not(end_time...)" with end_time set long before the use of
|
||||
# OneFuzz existed, identifies jobs whose end_time was never set.
|
||||
last_timestamp = (datetime.utcnow() - JOB_NEVER_STARTED_DURATION).isoformat()
|
||||
|
||||
time_filter = (
|
||||
f"Timestamp lt datetime'{last_timestamp}' and "
|
||||
"not(end_time ge datetime'2000-01-11T00:00:00.0Z')"
|
||||
)
|
||||
|
||||
for job in cls.search(
|
||||
query={
|
||||
"state": [JobState.enabled],
|
||||
},
|
||||
raw_unchecked_filter=time_filter,
|
||||
):
|
||||
for task in Task.search(query={"job_id": [job.job_id]}):
|
||||
task.mark_failed(
|
||||
Error(
|
||||
code=ErrorCode.TASK_FAILED,
|
||||
errors=["job never not start"],
|
||||
)
|
||||
)
|
||||
|
||||
logging.info(
|
||||
JOB_LOG_PREFIX + "stopping job that never started: %s", job.job_id
|
||||
)
|
||||
job.stopping()
|
||||
|
||||
def save_exclude(self) -> Optional[MappingIntStrAny]:
|
||||
return {"task_info": ...}
|
||||
|
||||
def telemetry_include(self) -> Optional[MappingIntStrAny]:
|
||||
return {
|
||||
"machine_id": ...,
|
||||
"state": ...,
|
||||
"scaleset_id": ...,
|
||||
}
|
||||
|
||||
def init(self) -> None:
|
||||
logging.info(JOB_LOG_PREFIX + "init: %s", self.job_id)
|
||||
self.state = JobState.enabled
|
||||
self.save()
|
||||
|
||||
def stop_if_all_done(self) -> None:
|
||||
not_stopped = [
|
||||
task
|
||||
for task in Task.search(query={"job_id": [self.job_id]})
|
||||
if task.state != TaskState.stopped
|
||||
]
|
||||
if not_stopped:
|
||||
return
|
||||
|
||||
logging.info(
|
||||
JOB_LOG_PREFIX + "stopping job as all tasks are stopped: %s", self.job_id
|
||||
)
|
||||
self.stopping()
|
||||
|
||||
def stopping(self) -> None:
|
||||
self.state = JobState.stopping
|
||||
logging.info(JOB_LOG_PREFIX + "stopping: %s", self.job_id)
|
||||
tasks = Task.search(query={"job_id": [self.job_id]})
|
||||
not_stopped = [task for task in tasks if task.state != TaskState.stopped]
|
||||
|
||||
if not_stopped:
|
||||
for task in not_stopped:
|
||||
task.mark_stopping()
|
||||
else:
|
||||
self.state = JobState.stopped
|
||||
task_info = [
|
||||
JobTaskStopped(
|
||||
task_id=x.task_id, error=x.error, task_type=x.config.task.type
|
||||
)
|
||||
for x in tasks
|
||||
]
|
||||
send_event(
|
||||
EventJobStopped(
|
||||
job_id=self.job_id,
|
||||
config=self.config,
|
||||
user_info=self.user_info,
|
||||
task_info=task_info,
|
||||
)
|
||||
)
|
||||
self.save()
|
||||
|
||||
def on_start(self) -> None:
|
||||
# try to keep this effectively idempotent
|
||||
if self.end_time is None:
|
||||
self.end_time = datetime.utcnow() + timedelta(hours=self.config.duration)
|
||||
self.save()
|
||||
|
||||
def save(self, new: bool = False, require_etag: bool = False) -> None:
|
||||
created = self.etag is None
|
||||
super().save(new=new, require_etag=require_etag)
|
||||
|
||||
if created:
|
||||
send_event(
|
||||
EventJobCreated(
|
||||
job_id=self.job_id, config=self.config, user_info=self.user_info
|
||||
)
|
||||
)
|
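stop_never_started_jobs above relies on a raw Azure Table filter string. A hedged, standalone sketch of how that filter is composed (the 30-day window mirrors JOB_NEVER_STARTED_DURATION):

    from datetime import datetime, timedelta

    last_timestamp = (datetime.utcnow() - timedelta(days=30)).isoformat()
    time_filter = (
        f"Timestamp lt datetime'{last_timestamp}' and "
        "not(end_time ge datetime'2000-01-11T00:00:00.0Z')"
    )
    print(time_filter)   # passed as raw_unchecked_filter to the ORM search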
@ -1,294 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
from typing import Iterator, List, Optional, Tuple, Union
|
||||
from uuid import UUID
|
||||
|
||||
from azure.devops.connection import Connection
|
||||
from azure.devops.credentials import BasicAuthentication
|
||||
from azure.devops.exceptions import (
|
||||
AzureDevOpsAuthenticationError,
|
||||
AzureDevOpsClientError,
|
||||
AzureDevOpsClientRequestError,
|
||||
AzureDevOpsServiceError,
|
||||
)
|
||||
from azure.devops.v6_0.work_item_tracking.models import (
|
||||
CommentCreate,
|
||||
JsonPatchOperation,
|
||||
Wiql,
|
||||
WorkItem,
|
||||
)
|
||||
from azure.devops.v6_0.work_item_tracking.work_item_tracking_client import (
|
||||
WorkItemTrackingClient,
|
||||
)
|
||||
from memoization import cached
|
||||
from onefuzztypes.models import ADOTemplate, RegressionReport, Report
|
||||
from onefuzztypes.primitives import Container
|
||||
|
||||
from ..secrets import get_secret_string_value
|
||||
from .common import Render, log_failed_notification
|
||||
|
||||
|
||||
class AdoNotificationException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@cached(ttl=60)
|
||||
def get_ado_client(base_url: str, token: str) -> WorkItemTrackingClient:
|
||||
connection = Connection(base_url=base_url, creds=BasicAuthentication("PAT", token))
|
||||
client = connection.clients_v6_0.get_work_item_tracking_client()
|
||||
return client
|
||||
|
||||
|
||||
@cached(ttl=60)
|
||||
def get_valid_fields(
|
||||
client: WorkItemTrackingClient, project: Optional[str] = None
|
||||
) -> List[str]:
|
||||
valid_fields = [
|
||||
x.reference_name.lower()
|
||||
for x in client.get_fields(project=project, expand="ExtensionFields")
|
||||
]
|
||||
return valid_fields
|
||||
|
||||
|
||||
class ADO:
|
||||
def __init__(
|
||||
self,
|
||||
container: Container,
|
||||
filename: str,
|
||||
config: ADOTemplate,
|
||||
report: Report,
|
||||
*,
|
||||
renderer: Optional[Render] = None,
|
||||
):
|
||||
self.config = config
|
||||
if renderer:
|
||||
self.renderer = renderer
|
||||
else:
|
||||
self.renderer = Render(container, filename, report)
|
||||
self.project = self.render(self.config.project)
|
||||
|
||||
def connect(self) -> None:
|
||||
auth_token = get_secret_string_value(self.config.auth_token)
|
||||
self.client = get_ado_client(self.config.base_url, auth_token)
|
||||
|
||||
def render(self, template: str) -> str:
|
||||
return self.renderer.render(template)
|
||||
|
||||
def existing_work_items(self) -> Iterator[WorkItem]:
|
||||
filters = {}
|
||||
for key in self.config.unique_fields:
|
||||
if key == "System.TeamProject":
|
||||
value = self.render(self.config.project)
|
||||
else:
|
||||
value = self.render(self.config.ado_fields[key])
|
||||
filters[key.lower()] = value
|
||||
|
||||
valid_fields = get_valid_fields(
|
||||
self.client, project=filters.get("system.teamproject")
|
||||
)
|
||||
|
||||
post_query_filter = {}
|
||||
|
||||
# WIQL (Work Item Query Language) is an SQL like query language that
|
||||
# doesn't support query params, safe quoting, or any other SQL-injection
|
||||
# protection mechanisms.
|
||||
#
|
||||
# As such, build the WIQL with those fields we can pre-determine are
|
||||
# "safe" and otherwise use post-query filtering.
|
||||
parts = []
|
||||
for k, v in filters.items():
|
||||
# Only add fields pre-validated against the system's field list to the query
|
||||
if k not in valid_fields:
|
||||
post_query_filter[k] = v
|
||||
continue
|
||||
|
||||
# WIQL supports wrapping values in ' or " and escaping ' by doubling it
#
# For this System.Title: hi'there
# use this query fragment: [System.Title] = 'hi''there'
#
# For this System.Title: hi"there
# use this query fragment: [System.Title] = 'hi"there'
#
# For this System.Title: hi'"there
# use this query fragment: [System.Title] = 'hi''"there'
SINGLE = "'"
|
||||
parts.append("[%s] = '%s'" % (k, v.replace(SINGLE, SINGLE + SINGLE)))
|
||||
|
||||
query = "select [System.Id] from WorkItems"
|
||||
if parts:
|
||||
query += " where " + " AND ".join(parts)
|
||||
|
||||
wiql = Wiql(query=query)
|
||||
for entry in self.client.query_by_wiql(wiql).work_items:
|
||||
item = self.client.get_work_item(entry.id, expand="Fields")
|
||||
lowered_fields = {x.lower(): str(y) for (x, y) in item.fields.items()}
|
||||
if post_query_filter and not all(
|
||||
[
|
||||
k.lower() in lowered_fields and lowered_fields[k.lower()] == v
|
||||
for (k, v) in post_query_filter.items()
|
||||
]
|
||||
):
|
||||
continue
|
||||
yield item
|
||||
|
||||
def update_existing(self, item: WorkItem, notification_info: str) -> None:
|
||||
if self.config.on_duplicate.comment:
|
||||
comment = self.render(self.config.on_duplicate.comment)
|
||||
self.client.add_comment(
|
||||
CommentCreate(comment),
|
||||
self.project,
|
||||
item.id,
|
||||
)
|
||||
|
||||
document = []
|
||||
for field in self.config.on_duplicate.increment:
|
||||
value = int(item.fields[field]) if field in item.fields else 0
|
||||
value += 1
|
||||
document.append(
|
||||
JsonPatchOperation(
|
||||
op="Replace", path="/fields/%s" % field, value=str(value)
|
||||
)
|
||||
)
|
||||
|
||||
for field in self.config.on_duplicate.ado_fields:
|
||||
field_value = self.render(self.config.on_duplicate.ado_fields[field])
|
||||
document.append(
|
||||
JsonPatchOperation(
|
||||
op="Replace", path="/fields/%s" % field, value=field_value
|
||||
)
|
||||
)
|
||||
|
||||
if item.fields["System.State"] in self.config.on_duplicate.set_state:
|
||||
document.append(
|
||||
JsonPatchOperation(
|
||||
op="Replace",
|
||||
path="/fields/System.State",
|
||||
value=self.config.on_duplicate.set_state[
|
||||
item.fields["System.State"]
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
if document:
|
||||
self.client.update_work_item(document, item.id, project=self.project)
|
||||
logging.info(
|
||||
f"notify ado: updated work item {item.id} - {notification_info}"
|
||||
)
|
||||
else:
|
||||
logging.info(
|
||||
f"notify ado: no update for work item {item.id} - {notification_info}"
|
||||
)
|
||||
|
||||
def render_new(self) -> Tuple[str, List[JsonPatchOperation]]:
|
||||
task_type = self.render(self.config.type)
|
||||
document = []
|
||||
if "System.Tags" not in self.config.ado_fields:
|
||||
document.append(
|
||||
JsonPatchOperation(
|
||||
op="Add", path="/fields/System.Tags", value="Onefuzz"
|
||||
)
|
||||
)
|
||||
|
||||
for field in self.config.ado_fields:
|
||||
value = self.render(self.config.ado_fields[field])
|
||||
if field == "System.Tags":
|
||||
value += ";Onefuzz"
|
||||
document.append(
|
||||
JsonPatchOperation(op="Add", path="/fields/%s" % field, value=value)
|
||||
)
|
||||
return (task_type, document)
|
||||
|
||||
def create_new(self) -> WorkItem:
|
||||
task_type, document = self.render_new()
|
||||
|
||||
entry = self.client.create_work_item(
|
||||
document=document, project=self.project, type=task_type
|
||||
)
|
||||
|
||||
if self.config.comment:
|
||||
comment = self.render(self.config.comment)
|
||||
self.client.add_comment(
|
||||
CommentCreate(comment),
|
||||
self.project,
|
||||
entry.id,
|
||||
)
|
||||
return entry
|
||||
|
||||
def process(self, notification_info: str) -> None:
|
||||
seen = False
|
||||
for work_item in self.existing_work_items():
|
||||
self.update_existing(work_item, notification_info)
|
||||
seen = True
|
||||
|
||||
if not seen:
|
||||
entry = self.create_new()
|
||||
logging.info(
|
||||
"notify ado: created new work item" f" {entry.id} - {notification_info}"
|
||||
)
|
||||
|
||||
|
||||
def is_transient(err: Exception) -> bool:
|
||||
error_codes = [
|
||||
# "TF401349: An unexpected error has occurred, please verify your request and try again." # noqa: E501
|
||||
"TF401349",
|
||||
# TF26071: This work item has been changed by someone else since you opened it. You will need to refresh it and discard your changes. # noqa: E501
|
||||
"TF26071",
|
||||
]
|
||||
error_str = str(err)
|
||||
for code in error_codes:
|
||||
if code in error_str:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def notify_ado(
|
||||
config: ADOTemplate,
|
||||
container: Container,
|
||||
filename: str,
|
||||
report: Union[Report, RegressionReport],
|
||||
fail_task_on_transient_error: bool,
|
||||
notification_id: UUID,
|
||||
) -> None:
|
||||
if isinstance(report, RegressionReport):
|
||||
logging.info(
|
||||
"ado integration does not support regression reports. "
|
||||
"container:%s filename:%s",
|
||||
container,
|
||||
filename,
|
||||
)
|
||||
return
|
||||
|
||||
notification_info = (
|
||||
f"job_id:{report.job_id} task_id:{report.task_id}"
|
||||
f" container:{container} filename:{filename}"
|
||||
)
|
||||
|
||||
logging.info("notify ado: %s", notification_info)
|
||||
|
||||
try:
|
||||
ado = ADO(container, filename, config, report)
|
||||
ado.connect()
|
||||
ado.process(notification_info)
|
||||
except (
|
||||
AzureDevOpsAuthenticationError,
|
||||
AzureDevOpsClientError,
|
||||
AzureDevOpsServiceError,
|
||||
AzureDevOpsClientRequestError,
|
||||
ValueError,
|
||||
) as err:
|
||||
|
||||
if not fail_task_on_transient_error and is_transient(err):
|
||||
raise AdoNotificationException(
|
||||
f"transient ADO notification failure {notification_info}"
|
||||
) from err
|
||||
else:
|
||||
log_failed_notification(report, err, notification_id)
|
||||
raise AdoNotificationException(
|
||||
"Sending file changed event for notification %s to poison queue"
|
||||
% notification_id
|
||||
) from err
|
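The WIQL-building logic above relies on doubling single quotes because WIQL has no parameterized queries. A minimal, standalone sketch of that escaping rule follows; the helper name and the sample titles are illustrative, not part of the module:

def wiql_quote(value: str) -> str:
    # wrap a value in single quotes, escaping embedded quotes by doubling them
    return "'%s'" % value.replace("'", "''")


if __name__ == "__main__":
    assert wiql_quote("hi'there") == "'hi''there'"
    assert wiql_quote('hi"there') == "'hi\"there'"
    print("[System.Title] = " + wiql_quote("hi'\"there"))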
@ -1,96 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from jinja2.sandbox import SandboxedEnvironment
|
||||
from onefuzztypes.models import Report
|
||||
from onefuzztypes.primitives import Container
|
||||
|
||||
from ..azure.containers import auth_download_url
|
||||
from ..azure.creds import get_instance_url
|
||||
from ..jobs import Job
|
||||
from ..tasks.config import get_setup_container
|
||||
from ..tasks.main import Task
|
||||
|
||||
|
||||
def log_failed_notification(
|
||||
report: Report, error: Exception, notification_id: UUID
|
||||
) -> None:
|
||||
logging.error(
|
||||
"notification failed: notification_id:%s job_id:%s task_id:%s err:%s",
|
||||
notification_id,
|
||||
report.job_id,
|
||||
report.task_id,
|
||||
error,
|
||||
)
|
||||
|
||||
|
||||
class Render:
|
||||
def __init__(
|
||||
self,
|
||||
container: Container,
|
||||
filename: str,
|
||||
report: Report,
|
||||
*,
|
||||
task: Optional[Task] = None,
|
||||
job: Optional[Job] = None,
|
||||
target_url: Optional[str] = None,
|
||||
input_url: Optional[str] = None,
|
||||
report_url: Optional[str] = None,
|
||||
):
|
||||
self.report = report
|
||||
self.container = container
|
||||
self.filename = filename
|
||||
if not task:
|
||||
task = Task.get(report.job_id, report.task_id)
|
||||
if not task:
|
||||
raise ValueError(f"invalid task {report.task_id}")
|
||||
if not job:
|
||||
job = Job.get(report.job_id)
|
||||
if not job:
|
||||
raise ValueError(f"invalid job {report.job_id}")
|
||||
|
||||
self.task_config = task.config
|
||||
self.job_config = job.config
|
||||
self.env = SandboxedEnvironment()
|
||||
|
||||
self.target_url = target_url
|
||||
if not self.target_url:
|
||||
setup_container = get_setup_container(task.config)
|
||||
if setup_container:
|
||||
self.target_url = auth_download_url(
|
||||
setup_container, self.report.executable.replace("setup/", "", 1)
|
||||
)
|
||||
|
||||
if report_url:
|
||||
self.report_url = report_url
|
||||
else:
|
||||
self.report_url = auth_download_url(container, filename)
|
||||
|
||||
self.input_url = input_url
|
||||
if not self.input_url:
|
||||
if self.report.input_blob:
|
||||
self.input_url = auth_download_url(
|
||||
self.report.input_blob.container, self.report.input_blob.name
|
||||
)
|
||||
|
||||
def render(self, template: str) -> str:
|
||||
return self.env.from_string(template).render(
|
||||
{
|
||||
"report": self.report,
|
||||
"task": self.task_config,
|
||||
"job": self.job_config,
|
||||
"report_url": self.report_url,
|
||||
"input_url": self.input_url,
|
||||
"target_url": self.target_url,
|
||||
"report_container": self.container,
|
||||
"report_filename": self.filename,
|
||||
"repro_cmd": "onefuzz --endpoint %s repro create_and_connect %s %s"
|
||||
% (get_instance_url(), self.container, self.filename),
|
||||
}
|
||||
)
|
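A self-contained sketch of the sandboxed template rendering that Render performs, assuming only that jinja2 is installed; the template text and the context values are made-up examples, not data from a real report:

from jinja2.sandbox import SandboxedEnvironment

env = SandboxedEnvironment()
template = "crash in {{ report_container }}/{{ report_filename }} - repro: {{ repro_cmd }}"
context = {
    # illustrative values only; the real Render builds these from the report,
    # task config, job config, and SAS download URLs
    "report_container": "reports",
    "report_filename": "example-crash.json",
    "repro_cmd": "onefuzz repro create_and_connect reports example-crash.json",
}
print(env.from_string(template).render(context))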
@ -1,138 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from github3 import login
|
||||
from github3.exceptions import GitHubException
|
||||
from github3.issues import Issue
|
||||
from onefuzztypes.enums import GithubIssueSearchMatch
|
||||
from onefuzztypes.models import (
|
||||
GithubAuth,
|
||||
GithubIssueTemplate,
|
||||
RegressionReport,
|
||||
Report,
|
||||
)
|
||||
from onefuzztypes.primitives import Container
|
||||
|
||||
from ..secrets import get_secret_obj
|
||||
from .common import Render, log_failed_notification
|
||||
|
||||
|
||||
class GithubIssue:
|
||||
def __init__(
|
||||
self,
|
||||
config: GithubIssueTemplate,
|
||||
container: Container,
|
||||
filename: str,
|
||||
report: Report,
|
||||
):
|
||||
self.config = config
|
||||
self.report = report
|
||||
if isinstance(config.auth.secret, GithubAuth):
|
||||
auth = config.auth.secret
|
||||
else:
|
||||
auth = get_secret_obj(config.auth.secret.url, GithubAuth)
|
||||
|
||||
self.gh = login(username=auth.user, password=auth.personal_access_token)
|
||||
self.renderer = Render(container, filename, report)
|
||||
|
||||
def render(self, field: str) -> str:
|
||||
return self.renderer.render(field)
|
||||
|
||||
def existing(self) -> List[Issue]:
|
||||
query = [
|
||||
self.render(self.config.unique_search.string),
|
||||
"repo:%s/%s"
|
||||
% (
|
||||
self.render(self.config.organization),
|
||||
self.render(self.config.repository),
|
||||
),
|
||||
]
|
||||
if self.config.unique_search.author:
|
||||
query.append("author:%s" % self.render(self.config.unique_search.author))
|
||||
|
||||
if self.config.unique_search.state:
|
||||
query.append("state:%s" % self.config.unique_search.state.name)
|
||||
|
||||
issues = []
|
||||
title = self.render(self.config.title)
|
||||
body = self.render(self.config.body)
|
||||
for issue in self.gh.search_issues(" ".join(query)):
|
||||
skip = False
|
||||
for field in self.config.unique_search.field_match:
|
||||
if field == GithubIssueSearchMatch.title and issue.title != title:
|
||||
skip = True
|
||||
break
|
||||
if field == GithubIssueSearchMatch.body and issue.body != body:
|
||||
skip = True
|
||||
break
|
||||
if not skip:
|
||||
issues.append(issue)
|
||||
|
||||
return issues
|
||||
|
||||
def update(self, issue: Issue) -> None:
|
||||
logging.info("updating issue: %s", issue)
|
||||
if self.config.on_duplicate.comment:
|
||||
issue.issue.create_comment(self.render(self.config.on_duplicate.comment))
|
||||
if self.config.on_duplicate.labels:
|
||||
labels = [self.render(x) for x in self.config.on_duplicate.labels]
|
||||
issue.issue.edit(labels=labels)
|
||||
if self.config.on_duplicate.reopen and issue.state != "open":
|
||||
issue.issue.edit(state="open")
|
||||
|
||||
def create(self) -> None:
|
||||
logging.info("creating issue")
|
||||
|
||||
assignees = [self.render(x) for x in self.config.assignees]
|
||||
labels = list(set(["OneFuzz"] + [self.render(x) for x in self.config.labels]))
|
||||
|
||||
self.gh.create_issue(
|
||||
self.render(self.config.organization),
|
||||
self.render(self.config.repository),
|
||||
self.render(self.config.title),
|
||||
body=self.render(self.config.body),
|
||||
labels=labels,
|
||||
assignees=assignees,
|
||||
)
|
||||
|
||||
def process(self) -> None:
|
||||
issues = self.existing()
|
||||
if issues:
|
||||
self.update(issues[0])
|
||||
else:
|
||||
self.create()
|
||||
|
||||
|
||||
def github_issue(
|
||||
config: GithubIssueTemplate,
|
||||
container: Container,
|
||||
filename: str,
|
||||
report: Optional[Union[Report, RegressionReport]],
|
||||
notification_id: UUID,
|
||||
) -> None:
|
||||
if report is None:
|
||||
return
|
||||
if isinstance(report, RegressionReport):
|
||||
logging.info(
|
||||
"github issue integration does not support regression reports. "
|
||||
"container:%s filename:%s",
|
||||
container,
|
||||
filename,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
handler = GithubIssue(config, container, filename, report)
|
||||
handler.process()
|
||||
except (GitHubException, ValueError) as err:
|
||||
log_failed_notification(report, err, notification_id)
|
||||
raise GitHubException(
|
||||
"Sending file change event for notification %s to poison queue"
|
||||
% notification_id
|
||||
) from err
|
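A dependency-free sketch of how the de-duplication search query above is assembled before being handed to the GitHub search API; the organization, repository, author, and search string are placeholders:

def build_issue_query(search_string, org, repo, author=None, state=None):
    # mirror the query assembly in GithubIssue.existing(): free-text search term
    # plus repo:, author:, and state: qualifiers
    query = [search_string, "repo:%s/%s" % (org, repo)]
    if author:
        query.append("author:%s" % author)
    if state:
        query.append("state:%s" % state)
    return " ".join(query)


print(build_issue_query("example-crash deadbeef", "contoso", "target-repo",
                        author="onefuzz-bot", state="open"))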
@ -1,200 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, Sequence, Tuple
|
||||
from uuid import UUID
|
||||
|
||||
from memoization import cached
|
||||
from onefuzztypes import models
|
||||
from onefuzztypes.enums import ErrorCode, TaskState
|
||||
from onefuzztypes.events import (
|
||||
EventCrashReported,
|
||||
EventFileAdded,
|
||||
EventRegressionReported,
|
||||
)
|
||||
from onefuzztypes.models import (
|
||||
ADOTemplate,
|
||||
Error,
|
||||
GithubIssueTemplate,
|
||||
NotificationTemplate,
|
||||
RegressionReport,
|
||||
Report,
|
||||
Result,
|
||||
TeamsTemplate,
|
||||
)
|
||||
from onefuzztypes.primitives import Container
|
||||
|
||||
from ..azure.containers import container_exists, get_file_sas_url
|
||||
from ..azure.queue import send_message
|
||||
from ..azure.storage import StorageType
|
||||
from ..events import send_event
|
||||
from ..orm import ORMMixin
|
||||
from ..reports import get_report_or_regression
|
||||
from ..tasks.config import get_input_container_queues
|
||||
from ..tasks.main import Task
|
||||
from .ado import notify_ado
|
||||
from .github_issues import github_issue
|
||||
from .teams import notify_teams
|
||||
|
||||
|
||||
class Notification(models.Notification, ORMMixin):
|
||||
@classmethod
|
||||
def get_by_id(cls, notification_id: UUID) -> Result["Notification"]:
|
||||
notifications = cls.search(query={"notification_id": [notification_id]})
|
||||
if not notifications:
|
||||
return Error(
|
||||
code=ErrorCode.INVALID_REQUEST, errors=["unable to find Notification"]
|
||||
)
|
||||
|
||||
if len(notifications) != 1:
|
||||
return Error(
|
||||
code=ErrorCode.INVALID_REQUEST,
|
||||
errors=["error identifying Notification"],
|
||||
)
|
||||
notification = notifications[0]
|
||||
return notification
|
||||
|
||||
@classmethod
|
||||
def key_fields(cls) -> Tuple[str, str]:
|
||||
return ("notification_id", "container")
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls, container: Container, config: NotificationTemplate, replace_existing: bool
|
||||
) -> Result["Notification"]:
|
||||
if not container_exists(container, StorageType.corpus):
|
||||
return Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid container"])
|
||||
|
||||
if replace_existing:
|
||||
existing = cls.search(query={"container": [container]})
|
||||
for entry in existing:
|
||||
logging.info(
|
||||
"replacing existing notification: %s - %s",
|
||||
entry.notification_id,
|
||||
container,
|
||||
)
|
||||
entry.delete()
|
||||
|
||||
entry = cls(container=container, config=config)
|
||||
entry.save()
|
||||
logging.info(
|
||||
"created notification. notification_id:%s container:%s",
|
||||
entry.notification_id,
|
||||
entry.container,
|
||||
)
|
||||
return entry
|
||||
|
||||
|
||||
@cached(ttl=10)
|
||||
def get_notifications(container: Container) -> List[Notification]:
|
||||
return Notification.search(query={"container": [container]})
|
||||
|
||||
|
||||
def get_regression_report_task(report: RegressionReport) -> Optional[Task]:
|
||||
# crash_test_result is required, but report & no_repro are not
|
||||
if report.crash_test_result.crash_report:
|
||||
return Task.get(
|
||||
report.crash_test_result.crash_report.job_id,
|
||||
report.crash_test_result.crash_report.task_id,
|
||||
)
|
||||
if report.crash_test_result.no_repro:
|
||||
return Task.get(
|
||||
report.crash_test_result.no_repro.job_id,
|
||||
report.crash_test_result.no_repro.task_id,
|
||||
)
|
||||
|
||||
logging.error(
|
||||
"unable to find crash_report or no_repro entry for report: %s",
|
||||
report.json(include_none=False),
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@cached(ttl=10)
|
||||
def get_queue_tasks() -> Sequence[Tuple[Task, Sequence[str]]]:
|
||||
results = []
|
||||
for task in Task.search_states(states=TaskState.available()):
|
||||
containers = get_input_container_queues(task.config)
|
||||
if containers:
|
||||
results.append((task, containers))
|
||||
return results
|
||||
|
||||
|
||||
def new_files(
|
||||
container: Container, filename: str, fail_task_on_transient_error: bool
|
||||
) -> None:
|
||||
notifications = get_notifications(container)
|
||||
|
||||
report = get_report_or_regression(
|
||||
container, filename, expect_reports=bool(notifications)
|
||||
)
|
||||
|
||||
if notifications:
|
||||
done = []
|
||||
for notification in notifications:
|
||||
# ignore duplicate configurations
|
||||
if notification.config in done:
|
||||
continue
|
||||
done.append(notification.config)
|
||||
|
||||
if isinstance(notification.config, TeamsTemplate):
|
||||
notify_teams(
|
||||
notification.config,
|
||||
container,
|
||||
filename,
|
||||
report,
|
||||
notification.notification_id,
|
||||
)
|
||||
|
||||
if not report:
|
||||
continue
|
||||
|
||||
if isinstance(notification.config, ADOTemplate):
|
||||
notify_ado(
|
||||
notification.config,
|
||||
container,
|
||||
filename,
|
||||
report,
|
||||
fail_task_on_transient_error,
|
||||
notification.notification_id,
|
||||
)
|
||||
|
||||
if isinstance(notification.config, GithubIssueTemplate):
|
||||
github_issue(
|
||||
notification.config,
|
||||
container,
|
||||
filename,
|
||||
report,
|
||||
notification.notification_id,
|
||||
)
|
||||
|
||||
for (task, containers) in get_queue_tasks():
|
||||
if container in containers:
|
||||
logging.info("queuing input %s %s %s", container, filename, task.task_id)
|
||||
url = get_file_sas_url(
|
||||
container, filename, StorageType.corpus, read=True, delete=True
|
||||
)
|
||||
send_message(task.task_id, bytes(url, "utf-8"), StorageType.corpus)
|
||||
|
||||
if isinstance(report, Report):
|
||||
crash_report_event = EventCrashReported(
|
||||
report=report, container=container, filename=filename
|
||||
)
|
||||
report_task = Task.get(report.job_id, report.task_id)
|
||||
if report_task:
|
||||
crash_report_event.task_config = report_task.config
|
||||
send_event(crash_report_event)
|
||||
elif isinstance(report, RegressionReport):
|
||||
regression_event = EventRegressionReported(
|
||||
regression_report=report, container=container, filename=filename
|
||||
)
|
||||
|
||||
report_task = get_regression_report_task(report)
|
||||
if report_task:
|
||||
regression_event.task_config = report_task.config
|
||||
send_event(regression_event)
|
||||
else:
|
||||
send_event(EventFileAdded(container=container, filename=filename))
|
@ -1,141 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
import requests
|
||||
from onefuzztypes.models import RegressionReport, Report, TeamsTemplate
|
||||
from onefuzztypes.primitives import Container
|
||||
|
||||
from ..azure.containers import auth_download_url
|
||||
from ..secrets import get_secret_string_value
|
||||
from ..tasks.config import get_setup_container
|
||||
from ..tasks.main import Task
|
||||
|
||||
|
||||
def markdown_escape(data: str) -> str:
|
||||
values = r"\\*_{}[]()#+-.!" # noqa: P103
|
||||
for value in values:
|
||||
data = data.replace(value, "\\" + value)
|
||||
data = data.replace("`", "``")
|
||||
return data
|
||||
|
||||
|
||||
def code_block(data: str) -> str:
|
||||
data = data.replace("`", "``")
|
||||
return "\n```\n%s\n```\n" % data
|
||||
|
||||
|
||||
def send_teams_webhook(
|
||||
config: TeamsTemplate,
|
||||
title: str,
|
||||
facts: List[Dict[str, str]],
|
||||
text: Optional[str],
|
||||
notification_id: UUID,
|
||||
) -> None:
|
||||
title = markdown_escape(title)
|
||||
|
||||
message: Dict[str, Any] = {
|
||||
"@type": "MessageCard",
|
||||
"@context": "https://schema.org/extensions",
|
||||
"summary": title,
|
||||
"sections": [{"activityTitle": title, "facts": facts}],
|
||||
}
|
||||
|
||||
if text:
|
||||
message["sections"].append({"text": text})
|
||||
|
||||
config_url = get_secret_string_value(config.url)
|
||||
response = requests.post(config_url, json=message)
|
||||
if not response.ok:
|
||||
logging.error(
|
||||
"webhook failed notification_id:%s %s %s",
|
||||
notification_id,
|
||||
response.status_code,
|
||||
response.content,
|
||||
)
|
||||
|
||||
|
||||
def notify_teams(
|
||||
config: TeamsTemplate,
|
||||
container: Container,
|
||||
filename: str,
|
||||
report: Optional[Union[Report, RegressionReport]],
|
||||
notification_id: UUID,
|
||||
) -> None:
|
||||
text = None
|
||||
facts: List[Dict[str, str]] = []
|
||||
|
||||
if isinstance(report, Report):
|
||||
task = Task.get(report.job_id, report.task_id)
|
||||
if not task:
|
||||
logging.error(
|
||||
"report with invalid task %s:%s", report.job_id, report.task_id
|
||||
)
|
||||
return
|
||||
|
||||
title = "new crash in %s: %s @ %s" % (
|
||||
report.executable,
|
||||
report.crash_type,
|
||||
report.crash_site,
|
||||
)
|
||||
|
||||
links = [
|
||||
"[report](%s)" % auth_download_url(container, filename),
|
||||
]
|
||||
|
||||
setup_container = get_setup_container(task.config)
|
||||
if setup_container:
|
||||
links.append(
|
||||
"[executable](%s)"
|
||||
% auth_download_url(
|
||||
setup_container,
|
||||
report.executable.replace("setup/", "", 1),
|
||||
),
|
||||
)
|
||||
|
||||
if report.input_blob:
|
||||
links.append(
|
||||
"[input](%s)"
|
||||
% auth_download_url(
|
||||
report.input_blob.container, report.input_blob.name
|
||||
),
|
||||
)
|
||||
|
||||
facts += [
|
||||
{"name": "Files", "value": " | ".join(links)},
|
||||
{
|
||||
"name": "Task",
|
||||
"value": markdown_escape(
|
||||
"job_id: %s task_id: %s" % (report.job_id, report.task_id)
|
||||
),
|
||||
},
|
||||
{
|
||||
"name": "Repro",
|
||||
"value": code_block(
|
||||
"onefuzz repro create_and_connect %s %s" % (container, filename)
|
||||
),
|
||||
},
|
||||
]
|
||||
|
||||
text = "## Call Stack\n" + "\n".join(code_block(x) for x in report.call_stack)
|
||||
|
||||
else:
|
||||
title = "new file found"
|
||||
facts += [
|
||||
{
|
||||
"name": "file",
|
||||
"value": "[%s/%s](%s)"
|
||||
% (
|
||||
markdown_escape(container),
|
||||
markdown_escape(filename),
|
||||
auth_download_url(container, filename),
|
||||
),
|
||||
}
|
||||
]
|
||||
|
||||
send_teams_webhook(config, title, facts, text, notification_id)
|
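A small sketch of the MessageCard payload that gets posted to the Teams webhook, with the escaping idea from markdown_escape inlined; the title and fact values are invented for illustration and no HTTP request is made:

import json


def escape_markdown(data: str) -> str:
    # backslash-escape markdown metacharacters, then double backticks
    for ch in r"\*_{}[]()#+-.!":
        data = data.replace(ch, "\\" + ch)
    return data.replace("`", "``")


title = escape_markdown("new crash in fuzz.exe: heap-buffer-overflow @ parse()")
payload = {
    "@type": "MessageCard",
    "@context": "https://schema.org/extensions",
    "summary": title,
    "sections": [
        {"activityTitle": title, "facts": [{"name": "Files", "value": "report.json"}]}
    ],
}
print(json.dumps(payload, indent=2))  # this body would be POSTed to config.url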
@ -1,505 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from uuid import UUID
|
||||
|
||||
from azure.common import AzureConflictHttpError, AzureMissingResourceHttpError
|
||||
from onefuzztypes.enums import (
|
||||
ErrorCode,
|
||||
JobState,
|
||||
NodeState,
|
||||
PoolState,
|
||||
ScalesetState,
|
||||
TaskState,
|
||||
TelemetryEvent,
|
||||
UpdateType,
|
||||
VmState,
|
||||
)
|
||||
from onefuzztypes.models import Error, SecretData
|
||||
from onefuzztypes.primitives import Container, PoolName, Region
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Protocol
|
||||
|
||||
from .azure.table import get_client
|
||||
from .secrets import delete_remote_secret_data, save_to_keyvault
|
||||
from .telemetry import track_event_filtered
|
||||
from .updates import queue_update
|
||||
|
||||
A = TypeVar("A", bound="ORMMixin")
|
||||
|
||||
QUERY_VALUE_TYPES = Union[
|
||||
List[bool],
|
||||
List[int],
|
||||
List[str],
|
||||
List[UUID],
|
||||
List[Region],
|
||||
List[Container],
|
||||
List[PoolName],
|
||||
List[VmState],
|
||||
List[ScalesetState],
|
||||
List[JobState],
|
||||
List[TaskState],
|
||||
List[PoolState],
|
||||
List[NodeState],
|
||||
]
|
||||
QueryFilter = Dict[str, QUERY_VALUE_TYPES]
|
||||
|
||||
SAFE_STRINGS = (UUID, Container, Region, PoolName)
|
||||
KEY = Union[int, str, UUID, Enum]
|
||||
|
||||
HOURS = 60 * 60
|
||||
|
||||
|
||||
class HasState(Protocol):
|
||||
# TODO: this should be bound tighter than Any
# In the end, we want this to be an Enum. Specifically, one of
# the JobState, TaskState, etc. enums.
state: Any
|
||||
|
||||
def get_keys(self) -> Tuple[KEY, KEY]:
|
||||
...
|
||||
|
||||
|
||||
def process_state_update(obj: HasState) -> None:
|
||||
"""
|
||||
process a single state update, if the obj
|
||||
implements a function for that state
|
||||
"""
|
||||
|
||||
func = getattr(obj, obj.state.name, None)
|
||||
if func is None:
|
||||
return
|
||||
|
||||
keys = obj.get_keys()
|
||||
|
||||
logging.info(
|
||||
"processing state update: %s - %s - %s", type(obj), keys, obj.state.name
|
||||
)
|
||||
func()
|
||||
|
||||
|
||||
def process_state_updates(obj: HasState, max_updates: int = 5) -> None:
|
||||
"""process through the state machine for an object"""
|
||||
|
||||
for _ in range(max_updates):
|
||||
state = obj.state
|
||||
process_state_update(obj)
|
||||
new_state = obj.state
|
||||
if new_state == state:
|
||||
break
|
||||
|
||||
|
||||
def resolve(key: KEY) -> str:
|
||||
if isinstance(key, str):
|
||||
return key
|
||||
elif isinstance(key, UUID):
|
||||
return str(key)
|
||||
elif isinstance(key, Enum):
|
||||
return key.name
|
||||
elif isinstance(key, int):
|
||||
return str(key)
|
||||
raise NotImplementedError("unsupported type %s - %s" % (type(key), repr(key)))
|
||||
|
||||
|
||||
def build_filters(
|
||||
cls: Type[A], query_args: Optional[QueryFilter]
|
||||
) -> Tuple[Optional[str], QueryFilter]:
|
||||
if not query_args:
|
||||
return (None, {})
|
||||
|
||||
partition_key_field, row_key_field = cls.key_fields()
|
||||
|
||||
search_filter_parts = []
|
||||
post_filters: QueryFilter = {}
|
||||
|
||||
for field, values in query_args.items():
|
||||
if field not in cls.__fields__:
|
||||
raise ValueError("unexpected field %s: %s" % (repr(field), cls))
|
||||
|
||||
if not values:
|
||||
continue
|
||||
|
||||
if field == partition_key_field:
|
||||
field_name = "PartitionKey"
|
||||
elif field == row_key_field:
|
||||
field_name = "RowKey"
|
||||
else:
|
||||
field_name = field
|
||||
|
||||
parts: Optional[List[str]] = None
|
||||
if isinstance(values[0], bool):
|
||||
parts = []
|
||||
for x in values:
|
||||
if not isinstance(x, bool):
|
||||
raise TypeError("unexpected type")
|
||||
parts.append("%s eq %s" % (field_name, str(x).lower()))
|
||||
elif isinstance(values[0], int):
|
||||
parts = []
|
||||
for x in values:
|
||||
if not isinstance(x, int):
|
||||
raise TypeError("unexpected type")
|
||||
parts.append("%s eq %d" % (field_name, x))
|
||||
elif isinstance(values[0], Enum):
|
||||
parts = []
|
||||
for x in values:
|
||||
if not isinstance(x, Enum):
|
||||
raise TypeError("unexpected type")
|
||||
parts.append("%s eq '%s'" % (field_name, x.name))
|
||||
elif all(isinstance(x, SAFE_STRINGS) for x in values):
|
||||
parts = ["%s eq '%s'" % (field_name, x) for x in values]
|
||||
else:
|
||||
post_filters[field_name] = values
|
||||
|
||||
if parts:
|
||||
if len(parts) == 1:
|
||||
search_filter_parts.append(parts[0])
|
||||
else:
|
||||
search_filter_parts.append("(" + " or ".join(parts) + ")")
|
||||
|
||||
if search_filter_parts:
|
||||
return (" and ".join(search_filter_parts), post_filters)
|
||||
|
||||
return (None, post_filters)
|
||||
|
||||
|
||||
def post_filter(value: Any, filters: Optional[QueryFilter]) -> bool:
|
||||
if not filters:
|
||||
return True
|
||||
|
||||
for field in filters:
|
||||
if field not in value:
|
||||
return False
|
||||
if value[field] not in filters[field]:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
MappingIntStrAny = Mapping[Union[int, str], Any]
|
||||
# A = TypeVar("A", bound="Model")
|
||||
|
||||
|
||||
class ModelMixin(BaseModel):
|
||||
def export_exclude(self) -> Optional[MappingIntStrAny]:
|
||||
return None
|
||||
|
||||
def raw(
|
||||
self,
|
||||
*,
|
||||
by_alias: bool = False,
|
||||
exclude_none: bool = False,
|
||||
exclude: MappingIntStrAny = None,
|
||||
include: MappingIntStrAny = None,
|
||||
) -> Dict[str, Any]:
|
||||
# cycling through json means all wrapped types get resolved, such as UUID
|
||||
result: Dict[str, Any] = json.loads(
|
||||
self.json(
|
||||
by_alias=by_alias,
|
||||
exclude_none=exclude_none,
|
||||
exclude=exclude,
|
||||
include=include,
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
B = TypeVar("B", bound=BaseModel)
|
||||
|
||||
|
||||
def hide_secrets(data: B, hider: Callable[[SecretData], SecretData]) -> B:
|
||||
for field in data.__fields__:
|
||||
field_data = getattr(data, field)
|
||||
|
||||
if isinstance(field_data, SecretData):
|
||||
field_data = hider(field_data)
|
||||
elif isinstance(field_data, BaseModel):
|
||||
field_data = hide_secrets(field_data, hider)
|
||||
elif isinstance(field_data, list):
|
||||
field_data = [
|
||||
hide_secrets(x, hider) if isinstance(x, BaseModel) else x
|
||||
for x in field_data
|
||||
]
|
||||
elif isinstance(field_data, dict):
|
||||
for key in field_data:
|
||||
if not isinstance(field_data[key], BaseModel):
|
||||
continue
|
||||
field_data[key] = hide_secrets(field_data[key], hider)
|
||||
|
||||
setattr(data, field, field_data)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# NOTE: the actual deletion must come from the `deleter` callback function
|
||||
def delete_secrets(data: B, deleter: Callable[[SecretData], None]) -> None:
|
||||
for field in data.__fields__:
|
||||
field_data = getattr(data, field)
|
||||
if isinstance(field_data, SecretData):
|
||||
deleter(field_data)
|
||||
elif isinstance(field_data, BaseModel):
|
||||
delete_secrets(field_data, deleter)
|
||||
elif isinstance(field_data, list):
|
||||
for entry in field_data:
|
||||
if isinstance(entry, BaseModel):
|
||||
delete_secrets(entry, deleter)
|
||||
elif isinstance(entry, SecretData):
|
||||
deleter(entry)
|
||||
elif isinstance(field_data, dict):
|
||||
for value in field_data.values():
|
||||
if isinstance(value, BaseModel):
|
||||
delete_secrets(value, deleter)
|
||||
elif isinstance(value, SecretData):
|
||||
deleter(value)
|
||||
|
||||
|
||||
# NOTE: if you want to include Timestamp in a model that uses ORMMixin,
|
||||
# it must be maintained as part of the model.
|
||||
class ORMMixin(ModelMixin):
|
||||
etag: Optional[str]
|
||||
|
||||
@classmethod
|
||||
def table_name(cls: Type[A]) -> str:
|
||||
return cls.__name__
|
||||
|
||||
@classmethod
|
||||
def get(
|
||||
cls: Type[A], PartitionKey: KEY, RowKey: Optional[KEY] = None
|
||||
) -> Optional[A]:
|
||||
client = get_client(table=cls.table_name())
|
||||
partition_key = resolve(PartitionKey)
|
||||
row_key = resolve(RowKey) if RowKey else partition_key
|
||||
|
||||
try:
|
||||
raw = client.get_entity(cls.table_name(), partition_key, row_key)
|
||||
except AzureMissingResourceHttpError:
|
||||
return None
|
||||
return cls.load(raw)
|
||||
|
||||
@classmethod
|
||||
def key_fields(cls) -> Tuple[str, Optional[str]]:
|
||||
raise NotImplementedError("keys not defined")
|
||||
|
||||
# FILTERS:
# The following hooks are supported:
# * save_exclude: Specify fields to *exclude* from saving to Storage Tables
# * export_exclude: Specify the fields to *exclude* from sending to an external API
# * telemetry_include: Specify the fields to *include* for telemetry
#
# For implementation details see:
# https://pydantic-docs.helpmanual.io/usage/exporting_models/#advanced-include-and-exclude
def save_exclude(self) -> Optional[MappingIntStrAny]:
|
||||
return None
|
||||
|
||||
def export_exclude(self) -> Optional[MappingIntStrAny]:
|
||||
return {"etag": ...}
|
||||
|
||||
def telemetry_include(self) -> Optional[MappingIntStrAny]:
|
||||
return {}
|
||||
|
||||
def telemetry(self) -> Any:
|
||||
return self.raw(exclude_none=True, include=self.telemetry_include())
|
||||
|
||||
def get_keys(self) -> Tuple[KEY, KEY]:
|
||||
partition_key_field, row_key_field = self.key_fields()
|
||||
|
||||
partition_key = getattr(self, partition_key_field)
|
||||
if row_key_field:
|
||||
row_key = getattr(self, row_key_field)
|
||||
else:
|
||||
row_key = partition_key
|
||||
|
||||
return (partition_key, row_key)
|
||||
|
||||
def save(self, new: bool = False, require_etag: bool = False) -> Optional[Error]:
|
||||
self = hide_secrets(self, save_to_keyvault)
|
||||
# TODO: migrate to an inspect.signature() model
|
||||
raw = self.raw(by_alias=True, exclude_none=True, exclude=self.save_exclude())
|
||||
for key in raw:
|
||||
if not isinstance(raw[key], (str, int)):
|
||||
raw[key] = json.dumps(raw[key])
|
||||
|
||||
for field in self.__fields__:
|
||||
if field not in raw:
|
||||
continue
|
||||
# for datetime fields that passed through filtering, use the real value,
|
||||
# rather than a serialized form
|
||||
if self.__fields__[field].type_ == datetime:
|
||||
raw[field] = getattr(self, field)
|
||||
|
||||
partition_key_field, row_key_field = self.key_fields()
|
||||
|
||||
# PartitionKey and RowKey must be 'str'
|
||||
raw["PartitionKey"] = resolve(raw[partition_key_field])
|
||||
raw["RowKey"] = resolve(raw[row_key_field or partition_key_field])
|
||||
|
||||
del raw[partition_key_field]
|
||||
if row_key_field in raw:
|
||||
del raw[row_key_field]
|
||||
|
||||
client = get_client(table=self.table_name())
|
||||
|
||||
# never save the timestamp
|
||||
if "Timestamp" in raw:
|
||||
del raw["Timestamp"]
|
||||
|
||||
if new:
|
||||
try:
|
||||
self.etag = client.insert_entity(self.table_name(), raw)
|
||||
except AzureConflictHttpError:
|
||||
return Error(
|
||||
code=ErrorCode.UNABLE_TO_CREATE, errors=["entry already exists"]
|
||||
)
|
||||
elif self.etag and require_etag:
|
||||
self.etag = client.replace_entity(
|
||||
self.table_name(), raw, if_match=self.etag
|
||||
)
|
||||
else:
|
||||
self.etag = client.insert_or_replace_entity(self.table_name(), raw)
|
||||
|
||||
if self.table_name() in TelemetryEvent.__members__:
|
||||
telem = self.telemetry()
|
||||
if telem:
|
||||
track_event_filtered(TelemetryEvent[self.table_name()], telem)
|
||||
|
||||
return None
|
||||
|
||||
def delete(self) -> None:
|
||||
partition_key, row_key = self.get_keys()
|
||||
|
||||
delete_secrets(self, delete_remote_secret_data)
|
||||
|
||||
client = get_client()
|
||||
try:
|
||||
client.delete_entity(
|
||||
self.table_name(), resolve(partition_key), resolve(row_key)
|
||||
)
|
||||
except AzureMissingResourceHttpError:
|
||||
# It's OK if the component is already deleted
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def load(cls: Type[A], data: Dict[str, Union[str, bytes, bytearray]]) -> A:
|
||||
partition_key_field, row_key_field = cls.key_fields()
|
||||
|
||||
if partition_key_field in data:
|
||||
raise Exception(
|
||||
"duplicate PartitionKey field %s for %s"
|
||||
% (partition_key_field, cls.table_name())
|
||||
)
|
||||
if row_key_field in data:
|
||||
raise Exception(
|
||||
"duplicate RowKey field %s for %s" % (row_key_field, cls.table_name())
|
||||
)
|
||||
|
||||
data[partition_key_field] = data["PartitionKey"]
|
||||
if row_key_field is not None:
|
||||
data[row_key_field] = data["RowKey"]
|
||||
|
||||
del data["PartitionKey"]
|
||||
del data["RowKey"]
|
||||
|
||||
for key in inspect.signature(cls).parameters:
|
||||
if key not in data:
|
||||
continue
|
||||
|
||||
annotation = inspect.signature(cls).parameters[key].annotation
|
||||
|
||||
if inspect.isclass(annotation):
|
||||
if (
|
||||
issubclass(annotation, BaseModel)
|
||||
or issubclass(annotation, dict)
|
||||
or issubclass(annotation, list)
|
||||
):
|
||||
data[key] = json.loads(data[key])
|
||||
continue
|
||||
|
||||
if getattr(annotation, "__origin__", None) == Union and any(
|
||||
inspect.isclass(x) and issubclass(x, BaseModel)
|
||||
for x in annotation.__args__
|
||||
):
|
||||
data[key] = json.loads(data[key])
|
||||
continue
|
||||
|
||||
# Required for Python >=3.7. In 3.6, `Dict[_,_]` and `List[_]` annotations
# are classes according to `inspect.isclass`.
if getattr(annotation, "__origin__", None) in [dict, list]:
|
||||
data[key] = json.loads(data[key])
|
||||
continue
|
||||
|
||||
return cls.parse_obj(data)
|
||||
|
||||
@classmethod
|
||||
def search(
|
||||
cls: Type[A],
|
||||
*,
|
||||
query: Optional[QueryFilter] = None,
|
||||
raw_unchecked_filter: Optional[str] = None,
|
||||
num_results: int = None,
|
||||
) -> List[A]:
|
||||
search_filter, post_filters = build_filters(cls, query)
|
||||
|
||||
if raw_unchecked_filter is not None:
|
||||
if search_filter is None:
|
||||
search_filter = raw_unchecked_filter
|
||||
else:
|
||||
search_filter = "(%s) and (%s)" % (search_filter, raw_unchecked_filter)
|
||||
|
||||
client = get_client(table=cls.table_name())
|
||||
entries = []
|
||||
for row in client.query_entities(
|
||||
cls.table_name(), filter=search_filter, num_results=num_results
|
||||
):
|
||||
if not post_filter(row, post_filters):
|
||||
continue
|
||||
|
||||
entry = cls.load(row)
|
||||
entries.append(entry)
|
||||
return entries
|
||||
|
||||
def queue(
|
||||
self,
|
||||
*,
|
||||
method: Optional[Callable] = None,
|
||||
visibility_timeout: Optional[int] = None,
|
||||
) -> None:
|
||||
if not hasattr(self, "state"):
|
||||
raise NotImplementedError("Queued an ORM mapping without State")
|
||||
|
||||
update_type = UpdateType.__members__.get(type(self).__name__)
|
||||
if update_type is None:
|
||||
raise NotImplementedError("unsupported update type: %s" % self)
|
||||
|
||||
method_name: Optional[str] = None
|
||||
if method is not None:
|
||||
if not hasattr(method, "__name__"):
|
||||
raise Exception("unable to queue method: %s" % method)
|
||||
method_name = method.__name__
|
||||
|
||||
partition_key, row_key = self.get_keys()
|
||||
|
||||
queue_update(
|
||||
update_type,
|
||||
resolve(partition_key),
|
||||
resolve(row_key),
|
||||
method=method_name,
|
||||
visibility_timeout=visibility_timeout,
|
||||
)
|
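A minimal illustration, outside the original module, of how key_fields() maps a model onto Azure Table PartitionKey/RowKey before a save; the Widget class and its fields are hypothetical:

from uuid import uuid4


class Widget:
    def __init__(self):
        self.job_id = uuid4()
        self.task_id = uuid4()

    @classmethod
    def key_fields(cls):
        # (partition key field, row key field), as in ORMMixin subclasses
        return ("job_id", "task_id")

    def get_keys(self):
        partition_field, row_field = self.key_fields()
        return (getattr(self, partition_field), getattr(self, row_field))


entry = Widget()
partition_key, row_key = (str(key) for key in entry.get_keys())
print(partition_key, row_key)  # both resolved to str before hitting Table storage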
@ -1,346 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Optional, Tuple
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import base58
|
||||
from azure.mgmt.compute.models import VirtualMachine
|
||||
from onefuzztypes.enums import ErrorCode, VmState
|
||||
from onefuzztypes.events import (
|
||||
EventProxyCreated,
|
||||
EventProxyDeleted,
|
||||
EventProxyFailed,
|
||||
EventProxyStateUpdated,
|
||||
)
|
||||
from onefuzztypes.models import (
|
||||
Authentication,
|
||||
Error,
|
||||
Forward,
|
||||
ProxyConfig,
|
||||
ProxyHeartbeat,
|
||||
)
|
||||
from onefuzztypes.primitives import Container, Region
|
||||
from pydantic import Field
|
||||
|
||||
from .__version__ import __version__
|
||||
from .azure.auth import build_auth
|
||||
from .azure.containers import get_file_sas_url, save_blob
|
||||
from .azure.creds import get_instance_id
|
||||
from .azure.ip import get_public_ip
|
||||
from .azure.nsg import NSG
|
||||
from .azure.queue import get_queue_sas
|
||||
from .azure.storage import StorageType
|
||||
from .azure.vm import VM
|
||||
from .config import InstanceConfig
|
||||
from .events import send_event
|
||||
from .extension import proxy_manager_extensions
|
||||
from .orm import ORMMixin, QueryFilter
|
||||
from .proxy_forward import ProxyForward
|
||||
|
||||
PROXY_LOG_PREFIX = "scaleset-proxy: "
|
||||
PROXY_LIFESPAN = datetime.timedelta(days=7)
|
||||
|
||||
|
||||
# This isn't intended to ever be shared with the client, hence not being in
# onefuzztypes
class Proxy(ORMMixin):
|
||||
timestamp: Optional[datetime.datetime] = Field(alias="Timestamp")
|
||||
created_timestamp: datetime.datetime = Field(
|
||||
default_factory=datetime.datetime.utcnow
|
||||
)
|
||||
proxy_id: UUID = Field(default_factory=uuid4)
|
||||
region: Region
|
||||
state: VmState = Field(default=VmState.init)
|
||||
auth: Authentication = Field(default_factory=build_auth)
|
||||
ip: Optional[str]
|
||||
error: Optional[Error]
|
||||
version: str = Field(default=__version__)
|
||||
heartbeat: Optional[ProxyHeartbeat]
|
||||
outdated: bool = Field(default=False)
|
||||
|
||||
@classmethod
|
||||
def key_fields(cls) -> Tuple[str, Optional[str]]:
|
||||
return ("region", "proxy_id")
|
||||
|
||||
def get_vm(self, config: InstanceConfig) -> VM:
|
||||
config = InstanceConfig.fetch()
|
||||
sku = config.proxy_vm_sku
|
||||
tags = None
|
||||
if config.vm_tags:
|
||||
tags = config.vm_tags
|
||||
vm = VM(
|
||||
name="proxy-%s" % base58.b58encode(self.proxy_id.bytes).decode(),
|
||||
region=self.region,
|
||||
sku=sku,
|
||||
image=config.default_linux_vm_image,
|
||||
auth=self.auth,
|
||||
tags=tags,
|
||||
)
|
||||
return vm
|
||||
|
||||
def init(self) -> None:
|
||||
config = InstanceConfig.fetch()
|
||||
vm = self.get_vm(config)
|
||||
vm_data = vm.get()
|
||||
if vm_data:
|
||||
if vm_data.provisioning_state == "Failed":
|
||||
self.set_provision_failed(vm_data)
|
||||
return
|
||||
else:
|
||||
self.save_proxy_config()
|
||||
self.set_state(VmState.extensions_launch)
|
||||
else:
|
||||
nsg = NSG(
|
||||
name=self.region,
|
||||
region=self.region,
|
||||
)
|
||||
|
||||
result = nsg.create()
|
||||
if isinstance(result, Error):
|
||||
self.set_failed(result)
|
||||
return
|
||||
|
||||
nsg_config = config.proxy_nsg_config
|
||||
result = nsg.set_allowed_sources(nsg_config)
|
||||
if isinstance(result, Error):
|
||||
self.set_failed(result)
|
||||
return
|
||||
|
||||
vm.nsg = nsg
|
||||
|
||||
result = vm.create()
|
||||
if isinstance(result, Error):
|
||||
self.set_failed(result)
|
||||
return
|
||||
self.save()
|
||||
|
||||
def set_provision_failed(self, vm_data: VirtualMachine) -> None:
|
||||
errors = ["provisioning failed"]
|
||||
for status in vm_data.instance_view.statuses:
|
||||
if status.level.name.lower() == "error":
|
||||
errors.append(
|
||||
f"code:{status.code} status:{status.display_status} "
|
||||
f"message:{status.message}"
|
||||
)
|
||||
|
||||
self.set_failed(
|
||||
Error(
|
||||
code=ErrorCode.PROXY_FAILED,
|
||||
errors=errors,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
def set_failed(self, error: Error) -> None:
|
||||
if self.error is not None:
|
||||
return
|
||||
|
||||
logging.error(PROXY_LOG_PREFIX + "vm failed: %s - %s", self.region, error)
|
||||
send_event(
|
||||
EventProxyFailed(region=self.region, proxy_id=self.proxy_id, error=error)
|
||||
)
|
||||
self.error = error
|
||||
self.set_state(VmState.stopping)
|
||||
|
||||
def extensions_launch(self) -> None:
|
||||
config = InstanceConfig.fetch()
|
||||
vm = self.get_vm(config)
|
||||
vm_data = vm.get()
|
||||
if not vm_data:
|
||||
self.set_failed(
|
||||
Error(
|
||||
code=ErrorCode.PROXY_FAILED,
|
||||
errors=["azure not able to find vm"],
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
if vm_data.provisioning_state == "Failed":
|
||||
self.set_provision_failed(vm_data)
|
||||
return
|
||||
|
||||
ip = get_public_ip(vm_data.network_profile.network_interfaces[0].id)
|
||||
if ip is None:
|
||||
self.save()
|
||||
return
|
||||
self.ip = ip
|
||||
|
||||
extensions = proxy_manager_extensions(self.region, self.proxy_id)
|
||||
result = vm.add_extensions(extensions)
|
||||
if isinstance(result, Error):
|
||||
self.set_failed(result)
|
||||
return
|
||||
elif result:
|
||||
self.set_state(VmState.running)
|
||||
|
||||
self.save()
|
||||
|
||||
def stopping(self) -> None:
|
||||
config = InstanceConfig.fetch()
|
||||
vm = self.get_vm(config)
|
||||
if not vm.is_deleted():
|
||||
logging.info(PROXY_LOG_PREFIX + "stopping proxy: %s", self.region)
|
||||
vm.delete()
|
||||
self.save()
|
||||
else:
|
||||
self.stopped()
|
||||
|
||||
def stopped(self) -> None:
|
||||
self.set_state(VmState.stopped)
|
||||
logging.info(PROXY_LOG_PREFIX + "removing proxy: %s", self.region)
|
||||
send_event(EventProxyDeleted(region=self.region, proxy_id=self.proxy_id))
|
||||
self.delete()
|
||||
|
||||
def is_outdated(self) -> bool:
|
||||
if self.state not in VmState.available():
|
||||
return True
|
||||
|
||||
if self.version != __version__:
|
||||
logging.info(
|
||||
PROXY_LOG_PREFIX + "mismatch version: proxy:%s service:%s state:%s",
|
||||
self.version,
|
||||
__version__,
|
||||
self.state,
|
||||
)
|
||||
return True
|
||||
if self.created_timestamp is not None:
|
||||
proxy_timestamp = self.created_timestamp
|
||||
if proxy_timestamp < (
|
||||
datetime.datetime.now(tz=datetime.timezone.utc) - PROXY_LIFESPAN
|
||||
):
|
||||
logging.info(
|
||||
PROXY_LOG_PREFIX
|
||||
+ "proxy older than 7 days:proxy-created:%s state:%s",
|
||||
self.created_timestamp,
|
||||
self.state,
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_used(self) -> bool:
|
||||
if len(self.get_forwards()) == 0:
|
||||
logging.info(PROXY_LOG_PREFIX + "no forwards: %s", self.region)
|
||||
return False
|
||||
return True
|
||||
|
||||
def is_alive(self) -> bool:
|
||||
# Unfortunately, both timezone-aware and timezone-naive values are required for
# the comparisons below, or exceptions are raised
ten_minutes_ago_no_tz = datetime.datetime.utcnow() - datetime.timedelta(
|
||||
minutes=10
|
||||
)
|
||||
ten_minutes_ago = ten_minutes_ago_no_tz.astimezone(datetime.timezone.utc)
|
||||
if (
|
||||
self.heartbeat is not None
|
||||
and self.heartbeat.timestamp < ten_minutes_ago_no_tz
|
||||
):
|
||||
logging.error(
|
||||
PROXY_LOG_PREFIX + "last heartbeat is more than 10 minutes old: "
|
||||
"%s - last heartbeat:%s compared_to:%s",
|
||||
self.region,
|
||||
self.heartbeat,
|
||||
ten_minutes_ago_no_tz,
|
||||
)
|
||||
return False
|
||||
|
||||
elif not self.heartbeat and self.timestamp and self.timestamp < ten_minutes_ago:
|
||||
logging.error(
|
||||
PROXY_LOG_PREFIX + "no heartbeat in the last 10 minutes: "
|
||||
"%s timestamp: %s compared_to:%s",
|
||||
self.region,
|
||||
self.timestamp,
|
||||
ten_minutes_ago,
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_forwards(self) -> List[Forward]:
|
||||
forwards: List[Forward] = []
|
||||
for entry in ProxyForward.search_forward(
|
||||
region=self.region, proxy_id=self.proxy_id
|
||||
):
|
||||
if entry.endtime < datetime.datetime.now(tz=datetime.timezone.utc):
|
||||
entry.delete()
|
||||
else:
|
||||
forwards.append(
|
||||
Forward(
|
||||
src_port=entry.port,
|
||||
dst_ip=entry.dst_ip,
|
||||
dst_port=entry.dst_port,
|
||||
)
|
||||
)
|
||||
return forwards
|
||||
|
||||
def save_proxy_config(self) -> None:
|
||||
forwards = self.get_forwards()
|
||||
proxy_config = ProxyConfig(
|
||||
url=get_file_sas_url(
|
||||
Container("proxy-configs"),
|
||||
"%s/%s/config.json" % (self.region, self.proxy_id),
|
||||
StorageType.config,
|
||||
read=True,
|
||||
),
|
||||
notification=get_queue_sas(
|
||||
"proxy",
|
||||
StorageType.config,
|
||||
add=True,
|
||||
),
|
||||
forwards=forwards,
|
||||
region=self.region,
|
||||
proxy_id=self.proxy_id,
|
||||
instance_telemetry_key=os.environ.get("APPINSIGHTS_INSTRUMENTATIONKEY"),
|
||||
microsoft_telemetry_key=os.environ.get("ONEFUZZ_TELEMETRY"),
|
||||
instance_id=get_instance_id(),
|
||||
)
|
||||
|
||||
save_blob(
|
||||
Container("proxy-configs"),
|
||||
"%s/%s/config.json" % (self.region, self.proxy_id),
|
||||
proxy_config.json(),
|
||||
StorageType.config,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def search_states(cls, *, states: Optional[List[VmState]] = None) -> List["Proxy"]:
|
||||
query: QueryFilter = {}
|
||||
if states:
|
||||
query["state"] = states
|
||||
return cls.search(query=query)
|
||||
|
||||
@classmethod
|
||||
def get_or_create(cls, region: Region) -> Optional["Proxy"]:
|
||||
proxy_list = Proxy.search(query={"region": [region], "outdated": [False]})
|
||||
for proxy in proxy_list:
|
||||
if proxy.is_outdated():
|
||||
proxy.outdated = True
|
||||
proxy.save()
|
||||
continue
|
||||
if proxy.state not in VmState.available():
|
||||
continue
|
||||
return proxy
|
||||
|
||||
logging.info(PROXY_LOG_PREFIX + "creating proxy: region:%s", region)
|
||||
proxy = Proxy(region=region)
|
||||
proxy.save()
|
||||
send_event(EventProxyCreated(region=region, proxy_id=proxy.proxy_id))
|
||||
return proxy
|
||||
|
||||
def set_state(self, state: VmState) -> None:
|
||||
if self.state == state:
|
||||
return
|
||||
|
||||
self.state = state
|
||||
self.save()
|
||||
|
||||
send_event(
|
||||
EventProxyStateUpdated(
|
||||
region=self.region, proxy_id=self.proxy_id, state=self.state
|
||||
)
|
||||
)
|
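A short sketch of the timezone handling that is_alive() works around: comparing a timezone-naive heartbeat against a timezone-aware cutoff raises a TypeError, which is why both forms are computed. The heartbeat value below is fabricated:

import datetime

ten_minutes_ago_no_tz = datetime.datetime.utcnow() - datetime.timedelta(minutes=10)
ten_minutes_ago = ten_minutes_ago_no_tz.astimezone(datetime.timezone.utc)

stale_heartbeat = datetime.datetime.utcnow() - datetime.timedelta(minutes=30)  # naive
print(stale_heartbeat < ten_minutes_ago_no_tz)  # True - proxy would be treated as dead

try:
    stale_heartbeat < ten_minutes_ago  # naive vs aware comparison
except TypeError as err:
    print("comparison failed:", err)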
@ -1,143 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from uuid import UUID
|
||||
|
||||
from onefuzztypes.enums import ErrorCode
|
||||
from onefuzztypes.models import Error, Forward
|
||||
from onefuzztypes.primitives import Region
|
||||
from pydantic import Field
|
||||
|
||||
from .azure.ip import get_scaleset_instance_ip
|
||||
from .orm import ORMMixin, QueryFilter
|
||||
|
||||
PORT_RANGES = range(28000, 32000)
|
||||
|
||||
|
||||
# This isn't intended to ever be shared with the client, hence not being in
# onefuzztypes
class ProxyForward(ORMMixin):
|
||||
region: Region
|
||||
port: int
|
||||
scaleset_id: UUID
|
||||
machine_id: UUID
|
||||
proxy_id: Optional[UUID]
|
||||
dst_ip: str
|
||||
dst_port: int
|
||||
endtime: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
|
||||
|
||||
@classmethod
|
||||
def key_fields(cls) -> Tuple[str, str]:
|
||||
return ("region", "port")
|
||||
|
||||
@classmethod
|
||||
def update_or_create(
|
||||
cls,
|
||||
region: Region,
|
||||
scaleset_id: UUID,
|
||||
machine_id: UUID,
|
||||
dst_port: int,
|
||||
duration: int,
|
||||
) -> Union["ProxyForward", Error]:
|
||||
private_ip = get_scaleset_instance_ip(scaleset_id, machine_id)
|
||||
if not private_ip:
|
||||
return Error(
|
||||
code=ErrorCode.UNABLE_TO_PORT_FORWARD, errors=["no private ip for node"]
|
||||
)
|
||||
|
||||
entries = cls.search_forward(
|
||||
scaleset_id=scaleset_id,
|
||||
machine_id=machine_id,
|
||||
dst_port=dst_port,
|
||||
region=region,
|
||||
)
|
||||
if entries:
|
||||
entry = entries[0]
|
||||
entry.endtime = datetime.datetime.utcnow() + datetime.timedelta(
|
||||
hours=duration
|
||||
)
|
||||
entry.save()
|
||||
return entry
|
||||
|
||||
existing = [int(x.port) for x in entries]
|
||||
for port in PORT_RANGES:
|
||||
if port in existing:
|
||||
continue
|
||||
|
||||
entry = cls(
|
||||
region=region,
|
||||
port=port,
|
||||
scaleset_id=scaleset_id,
|
||||
machine_id=machine_id,
|
||||
dst_ip=private_ip,
|
||||
dst_port=dst_port,
|
||||
endtime=datetime.datetime.utcnow() + datetime.timedelta(hours=duration),
|
||||
)
|
||||
result = entry.save(new=True)
|
||||
if isinstance(result, Error):
|
||||
logging.info("port is already used: %s", entry)
|
||||
continue
|
||||
|
||||
return entry
|
||||
|
||||
return Error(
|
||||
code=ErrorCode.UNABLE_TO_PORT_FORWARD, errors=["all forward ports used"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def remove_forward(
|
||||
cls,
|
||||
scaleset_id: UUID,
|
||||
*,
|
||||
proxy_id: Optional[UUID] = None,
|
||||
machine_id: Optional[UUID] = None,
|
||||
dst_port: Optional[int] = None,
|
||||
) -> List[Region]:
|
||||
entries = cls.search_forward(
|
||||
scaleset_id=scaleset_id,
|
||||
machine_id=machine_id,
|
||||
proxy_id=proxy_id,
|
||||
dst_port=dst_port,
|
||||
)
|
||||
regions = set()
|
||||
for entry in entries:
|
||||
regions.add(entry.region)
|
||||
entry.delete()
|
||||
return list(regions)
|
||||
|
||||
@classmethod
|
||||
def search_forward(
|
||||
cls,
|
||||
*,
|
||||
scaleset_id: Optional[UUID] = None,
|
||||
region: Optional[Region] = None,
|
||||
machine_id: Optional[UUID] = None,
|
||||
proxy_id: Optional[UUID] = None,
|
||||
dst_port: Optional[int] = None,
|
||||
) -> List["ProxyForward"]:
|
||||
|
||||
query: QueryFilter = {}
|
||||
if region is not None:
|
||||
query["region"] = [region]
|
||||
|
||||
if scaleset_id is not None:
|
||||
query["scaleset_id"] = [scaleset_id]
|
||||
|
||||
if machine_id is not None:
|
||||
query["machine_id"] = [machine_id]
|
||||
|
||||
if proxy_id is not None:
|
||||
query["proxy_id"] = [proxy_id]
|
||||
|
||||
if dst_port is not None:
|
||||
query["dst_port"] = [dst_port]
|
||||
|
||||
return cls.search(query=query)
|
||||
|
||||
def to_forward(self) -> Forward:
|
||||
return Forward(src_port=self.port, dst_ip=self.dst_ip, dst_port=self.dst_port)
|
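A tiny sketch of the port-allocation loop in update_or_create() above: take the first port in PORT_RANGES not already claimed by an existing forward. The set of used ports is made up:

PORT_RANGES = range(28000, 32000)


def next_free_port(used_ports):
    # first free port wins; None mirrors the "all forward ports used" error path
    for port in PORT_RANGES:
        if port not in used_ports:
            return port
    return None


print(next_free_port({28000, 28001, 28003}))  # 28002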
@ -1,165 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from sys import getsizeof
|
||||
from typing import Optional, Union
|
||||
|
||||
from memoization import cached
|
||||
from onefuzztypes.models import RegressionReport, Report
|
||||
from onefuzztypes.primitives import Container
|
||||
from pydantic import ValidationError
|
||||
|
||||
from .azure.containers import get_blob
|
||||
from .azure.storage import StorageType
|
||||
|
||||
|
||||
# This is a fix for the following error:
# Exception while executing function:
# Functions.queue_file_changes Result: Failure
# Exception: AzureHttpError: Bad Request
# "The property value exceeds the maximum allowed size (64KB).
# If the property value is a string, it is UTF-16 encoded and
# the maximum number of characters should be 32K or less."
def fix_report_size(
|
||||
content: str,
|
||||
report: Report,
|
||||
acceptable_report_length_kb: int = 24,
|
||||
keep_num_entries: int = 10,
|
||||
keep_string_len: int = 256,
|
||||
) -> Report:
|
||||
logging.info(f"report content length {getsizeof(content)}")
|
||||
if getsizeof(content) > acceptable_report_length_kb * 1024:
|
||||
msg = f"report data exceeds {acceptable_report_length_kb}K {getsizeof(content)}"
|
||||
if len(report.call_stack) > keep_num_entries:
|
||||
msg = msg + "; removing some of stack frames from the report"
|
||||
report.call_stack = report.call_stack[0:keep_num_entries] + ["..."]
|
||||
|
||||
if report.asan_log and len(report.asan_log) > keep_string_len:
|
||||
msg = msg + "; removing some of asan log entries from the report"
|
||||
report.asan_log = report.asan_log[0:keep_string_len] + "..."
|
||||
|
||||
if report.minimized_stack and len(report.minimized_stack) > keep_num_entries:
|
||||
msg = msg + "; removing some of minimized stack frames from the report"
|
||||
report.minimized_stack = report.minimized_stack[0:keep_num_entries] + [
|
||||
"..."
|
||||
]
|
||||
|
||||
if (
|
||||
report.minimized_stack_function_names
|
||||
and len(report.minimized_stack_function_names) > keep_num_entries
|
||||
):
|
||||
msg = (
|
||||
msg
|
||||
+ "; removing some of minimized stack function names from the report"
|
||||
)
|
||||
report.minimized_stack_function_names = (
|
||||
report.minimized_stack_function_names[0:keep_num_entries] + ["..."]
|
||||
)
|
||||
|
||||
if (
|
||||
report.minimized_stack_function_lines
|
||||
and len(report.minimized_stack_function_lines) > keep_num_entries
|
||||
):
|
||||
msg = (
|
||||
msg
|
||||
+ "; removing some of minimized stack function lines from the report"
|
||||
)
|
||||
report.minimized_stack_function_lines = (
|
||||
report.minimized_stack_function_lines[0:keep_num_entries] + ["..."]
|
||||
)
|
||||
|
||||
logging.info(msg)
|
||||
return report
|
||||
|
||||
|
||||
def parse_report_or_regression(
|
||||
content: Union[str, bytes],
|
||||
file_path: Optional[str] = None,
|
||||
expect_reports: bool = False,
|
||||
) -> Optional[Union[Report, RegressionReport]]:
|
||||
if isinstance(content, bytes):
|
||||
try:
|
||||
content = content.decode()
|
||||
except UnicodeDecodeError as err:
|
||||
if expect_reports:
|
||||
logging.error(
|
||||
f"unable to parse report ({file_path}): "
|
||||
f"unicode decode of report failed - {err}"
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
data = json.loads(content)
|
||||
except json.decoder.JSONDecodeError as err:
|
||||
if expect_reports:
|
||||
logging.error(
|
||||
f"unable to parse report ({file_path}): json decoding failed - {err}"
|
||||
)
|
||||
return None
|
||||
|
||||
regression_err = None
|
||||
try:
|
||||
regression_report = RegressionReport.parse_obj(data)
|
||||
|
||||
if (
|
||||
regression_report.crash_test_result is not None
|
||||
and regression_report.crash_test_result.crash_report is not None
|
||||
):
|
||||
regression_report.crash_test_result.crash_report = fix_report_size(
|
||||
content, regression_report.crash_test_result.crash_report
|
||||
)
|
||||
|
||||
if (
|
||||
regression_report.original_crash_test_result is not None
|
||||
and regression_report.original_crash_test_result.crash_report is not None
|
||||
):
|
||||
regression_report.original_crash_test_result.crash_report = fix_report_size(
|
||||
content, regression_report.original_crash_test_result.crash_report
|
||||
)
|
||||
return regression_report
|
||||
except ValidationError as err:
|
||||
regression_err = err
|
||||
|
||||
try:
|
||||
report = Report.parse_obj(data)
|
||||
return fix_report_size(content, report)
|
||||
except ValidationError as err:
|
||||
if expect_reports:
|
||||
logging.error(
|
||||
f"unable to parse report ({file_path}) as a report or regression. "
|
||||
f"regression error: {regression_err} report error: {err}"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
# cache the last 1000 reports
|
||||
@cached(max_size=1000)
|
||||
def get_report_or_regression(
|
||||
container: Container, filename: str, *, expect_reports: bool = False
|
||||
) -> Optional[Union[Report, RegressionReport]]:
|
||||
file_path = "/".join([container, filename])
|
||||
if not filename.endswith(".json"):
|
||||
if expect_reports:
|
||||
logging.error("get_report invalid extension: %s", file_path)
|
||||
return None
|
||||
|
||||
blob = get_blob(container, filename, StorageType.corpus)
|
||||
if blob is None:
|
||||
if expect_reports:
|
||||
logging.error("get_report invalid blob: %s", file_path)
|
||||
return None
|
||||
|
||||
return parse_report_or_regression(
|
||||
blob, file_path=file_path, expect_reports=expect_reports
|
||||
)
|
||||
|
||||
|
||||
def get_report(container: Container, filename: str) -> Optional[Report]:
|
||||
result = get_report_or_regression(container, filename)
|
||||
if isinstance(result, Report):
|
||||
return result
|
||||
return None
|
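A brief usage sketch (the payload is a hypothetical stub, not a real report) showing how a blob flows through the parser above and how callers dispatch on the result:

# Hedged sketch: dispatch on parse_report_or_regression, as get_report_or_regression does.
content = b'{"task_id": "00000000-0000-0000-0000-000000000000"}'  # illustrative stub
result = parse_report_or_regression(content, file_path="reports/example.json")
if isinstance(result, RegressionReport):
    print("parsed a regression report")
elif isinstance(result, Report):
    print("parsed a crash report")
else:
    print("payload is not a valid report")  # the likely outcome for this stub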
@ -1,286 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from azure.mgmt.compute.models import VirtualMachine
|
||||
from onefuzztypes.enums import OS, ContainerType, ErrorCode, VmState
|
||||
from onefuzztypes.models import Error
|
||||
from onefuzztypes.models import Repro as BASE_REPRO
|
||||
from onefuzztypes.models import ReproConfig, TaskVm, UserInfo
|
||||
from onefuzztypes.primitives import Container
|
||||
|
||||
from .azure.auth import build_auth
|
||||
from .azure.containers import save_blob
|
||||
from .azure.creds import get_base_region
|
||||
from .azure.ip import get_public_ip
|
||||
from .azure.nsg import NSG
|
||||
from .azure.storage import StorageType
|
||||
from .azure.vm import VM
|
||||
from .config import InstanceConfig
|
||||
from .extension import repro_extensions
|
||||
from .orm import ORMMixin, QueryFilter
|
||||
from .reports import get_report
|
||||
from .tasks.main import Task
|
||||
|
||||
DEFAULT_SKU = "Standard_DS1_v2"
|
||||
|
||||
|
||||
class Repro(BASE_REPRO, ORMMixin):
|
||||
def set_error(self, error: Error) -> None:
|
||||
logging.error(
|
||||
"repro failed: vm_id: %s task_id: %s: error: %s",
|
||||
self.vm_id,
|
||||
self.task_id,
|
||||
error,
|
||||
)
|
||||
self.error = error
|
||||
self.state = VmState.stopping
|
||||
self.save()
|
||||
|
||||
def get_vm(self, config: InstanceConfig) -> VM:
|
||||
tags = None
|
||||
if config.vm_tags:
|
||||
tags = config.vm_tags
|
||||
|
||||
task = Task.get_by_task_id(self.task_id)
|
||||
if isinstance(task, Error):
|
||||
raise Exception("previously existing task missing: %s" % self.task_id)
|
||||
|
||||
config = InstanceConfig.fetch()
|
||||
default_os = {
|
||||
OS.linux: config.default_linux_vm_image,
|
||||
OS.windows: config.default_windows_vm_image,
|
||||
}
|
||||
vm_config = task.get_repro_vm_config()
|
||||
if vm_config is None:
|
||||
# if using a pool without any scalesets defined yet, use reasonable defaults
|
||||
if task.os not in default_os:
|
||||
raise NotImplementedError("unsupported OS for repro %s" % task.os)
|
||||
|
||||
vm_config = TaskVm(
|
||||
region=get_base_region(), sku=DEFAULT_SKU, image=default_os[task.os]
|
||||
)
|
||||
|
||||
if self.auth is None:
|
||||
raise Exception("missing auth")
|
||||
|
||||
return VM(
|
||||
name=self.vm_id,
|
||||
region=vm_config.region,
|
||||
sku=vm_config.sku,
|
||||
image=vm_config.image,
|
||||
auth=self.auth,
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
    def init(self) -> None:
        config = InstanceConfig.fetch()
        vm = self.get_vm(config)
        vm_data = vm.get()
        if vm_data:
            if vm_data.provisioning_state == "Failed":
                # set_failed expects the fetched VirtualMachine, not the VM wrapper
                self.set_failed(vm_data)
            else:
                script_result = self.build_repro_script()
                if isinstance(script_result, Error):
                    self.set_error(script_result)
                    return

                self.state = VmState.extensions_launch
        else:
            nsg = NSG(
                name=vm.region,
                region=vm.region,
            )
            result = nsg.create()
            if isinstance(result, Error):
                # NSG setup returns an Error, not a VirtualMachine; record it via set_error
                self.set_error(result)
                return

            nsg_config = config.proxy_nsg_config
            result = nsg.set_allowed_sources(nsg_config)
            if isinstance(result, Error):
                self.set_error(result)
                return

            vm.nsg = nsg
            result = vm.create()
            if isinstance(result, Error):
                self.set_error(result)
                return
        self.save()
|
||||
|
||||
def set_failed(self, vm_data: VirtualMachine) -> None:
|
||||
errors = []
|
||||
for status in vm_data.instance_view.statuses:
|
||||
if status.level.name.lower() == "error":
|
||||
errors.append(
|
||||
"%s %s %s" % (status.code, status.display_status, status.message)
|
||||
)
|
||||
return self.set_error(Error(code=ErrorCode.VM_CREATE_FAILED, errors=errors))
|
||||
|
||||
def get_setup_container(self) -> Optional[Container]:
|
||||
task = Task.get_by_task_id(self.task_id)
|
||||
if isinstance(task, Task):
|
||||
for container in task.config.containers:
|
||||
if container.type == ContainerType.setup:
|
||||
return container.name
|
||||
return None
|
||||
|
||||
def extensions_launch(self) -> None:
|
||||
config = InstanceConfig.fetch()
|
||||
vm = self.get_vm(config)
|
||||
vm_data = vm.get()
|
||||
if not vm_data:
|
||||
self.set_error(
|
||||
Error(
|
||||
code=ErrorCode.VM_CREATE_FAILED,
|
||||
errors=["failed before launching extensions"],
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
if vm_data.provisioning_state == "Failed":
|
||||
self.set_failed(vm_data)
|
||||
return
|
||||
|
||||
if not self.ip:
|
||||
self.ip = get_public_ip(vm_data.network_profile.network_interfaces[0].id)
|
||||
|
||||
extensions = repro_extensions(
|
||||
vm.region, self.os, self.vm_id, self.config, self.get_setup_container()
|
||||
)
|
||||
result = vm.add_extensions(extensions)
|
||||
if isinstance(result, Error):
|
||||
self.set_error(result)
|
||||
return
|
||||
elif result:
|
||||
self.state = VmState.running
|
||||
|
||||
self.save()
|
||||
|
||||
def stopping(self) -> None:
|
||||
config = InstanceConfig.fetch()
|
||||
vm = self.get_vm(config)
|
||||
if not vm.is_deleted():
|
||||
logging.info("vm stopping: %s", self.vm_id)
|
||||
vm.delete()
|
||||
self.save()
|
||||
else:
|
||||
self.stopped()
|
||||
|
||||
def stopped(self) -> None:
|
||||
logging.info("vm stopped: %s", self.vm_id)
|
||||
self.delete()
|
||||
|
||||
def build_repro_script(self) -> Optional[Error]:
|
||||
if self.auth is None:
|
||||
return Error(code=ErrorCode.VM_CREATE_FAILED, errors=["missing auth"])
|
||||
|
||||
task = Task.get_by_task_id(self.task_id)
|
||||
if isinstance(task, Error):
|
||||
return task
|
||||
|
||||
report = get_report(self.config.container, self.config.path)
|
||||
if report is None:
|
||||
return Error(code=ErrorCode.VM_CREATE_FAILED, errors=["missing report"])
|
||||
|
||||
if report.input_blob is None:
|
||||
return Error(
|
||||
code=ErrorCode.VM_CREATE_FAILED,
|
||||
errors=["unable to perform repro for crash reports without inputs"],
|
||||
)
|
||||
|
||||
files = {}
|
||||
|
||||
if task.os == OS.windows:
|
||||
ssh_path = "$env:ProgramData/ssh/administrators_authorized_keys"
|
||||
cmds = [
|
||||
'Set-Content -Path %s -Value "%s"' % (ssh_path, self.auth.public_key),
|
||||
". C:\\onefuzz\\tools\\win64\\onefuzz.ps1",
|
||||
"Set-SetSSHACL",
|
||||
'while (1) { cdb -server tcp:port=1337 -c "g" setup\\%s %s }'
|
||||
% (
|
||||
task.config.task.target_exe,
|
||||
report.input_blob.name,
|
||||
),
|
||||
]
|
||||
cmd = "\r\n".join(cmds)
|
||||
files["repro.ps1"] = cmd
|
||||
elif task.os == OS.linux:
|
||||
gdb_fmt = (
|
||||
"ASAN_OPTIONS='abort_on_error=1' gdbserver "
|
||||
"%s /onefuzz/setup/%s /onefuzz/downloaded/%s"
|
||||
)
|
||||
cmd = "while :; do %s; done" % (
|
||||
gdb_fmt
|
||||
% (
|
||||
"localhost:1337",
|
||||
task.config.task.target_exe,
|
||||
report.input_blob.name,
|
||||
)
|
||||
)
|
||||
files["repro.sh"] = cmd
|
||||
|
||||
cmd = "#!/bin/bash\n%s" % (
|
||||
gdb_fmt % ("-", task.config.task.target_exe, report.input_blob.name)
|
||||
)
|
||||
files["repro-stdout.sh"] = cmd
|
||||
else:
|
||||
raise NotImplementedError("invalid task os: %s" % task.os)
|
||||
|
||||
for filename in files:
|
||||
save_blob(
|
||||
Container("repro-scripts"),
|
||||
"%s/%s" % (self.vm_id, filename),
|
||||
files[filename],
|
||||
StorageType.config,
|
||||
)
|
||||
|
||||
logging.info("saved repro script")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def search_states(cls, *, states: Optional[List[VmState]] = None) -> List["Repro"]:
|
||||
query: QueryFilter = {}
|
||||
if states:
|
||||
query["state"] = states
|
||||
return cls.search(query=query)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls, config: ReproConfig, user_info: Optional[UserInfo]
|
||||
) -> Union[Error, "Repro"]:
|
||||
report = get_report(config.container, config.path)
|
||||
if not report:
|
||||
return Error(
|
||||
code=ErrorCode.UNABLE_TO_FIND, errors=["unable to find report"]
|
||||
)
|
||||
|
||||
task = Task.get_by_task_id(report.task_id)
|
||||
if isinstance(task, Error):
|
||||
return task
|
||||
|
||||
vm = cls(config=config, task_id=task.task_id, os=task.os, auth=build_auth())
|
||||
if vm.end_time is None:
|
||||
vm.end_time = datetime.utcnow() + timedelta(hours=config.duration)
|
||||
|
||||
vm.user_info = user_info
|
||||
vm.save()
|
||||
|
||||
return vm
|
||||
|
||||
@classmethod
|
||||
def search_expired(cls) -> List["Repro"]:
|
||||
# unlike jobs/tasks, the entry is deleted from the backing table upon stop
|
||||
time_filter = "end_time lt datetime'%s'" % datetime.utcnow().isoformat()
|
||||
return cls.search(raw_unchecked_filter=time_filter)
|
||||
|
||||
@classmethod
|
||||
def key_fields(cls) -> Tuple[str, Optional[str]]:
|
||||
return ("vm_id", None)
|
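A small illustrative sketch (the duration value is hypothetical) of the expiry bookkeeping behind Repro.create and search_expired:

from datetime import datetime, timedelta

duration_hours = 2  # hypothetical ReproConfig.duration
end_time = datetime.utcnow() + timedelta(hours=duration_hours)
expired_filter = "end_time lt datetime'%s'" % datetime.utcnow().isoformat()
print(end_time, expired_filter)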
@ -1,118 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Sequence, Type, TypeVar, Union
|
||||
from uuid import UUID
|
||||
|
||||
from azure.functions import HttpRequest, HttpResponse
|
||||
from onefuzztypes.enums import ErrorCode
|
||||
from onefuzztypes.models import Error
|
||||
from onefuzztypes.responses import BaseResponse
|
||||
from pydantic import BaseModel # noqa: F401
|
||||
from pydantic import ValidationError
|
||||
|
||||
from .orm import ModelMixin
|
||||
|
||||
# We don't actually use these types at runtime at this time. Rather,
|
||||
# these are used in a bound TypeVar. MyPy suggests to only import these
|
||||
# types during type checking.
|
||||
if TYPE_CHECKING:
|
||||
from onefuzztypes.requests import BaseRequest # noqa: F401
|
||||
|
||||
|
||||
def ok(
|
||||
data: Union[BaseResponse, Sequence[BaseResponse], ModelMixin, Sequence[ModelMixin]]
|
||||
) -> HttpResponse:
|
||||
if isinstance(data, BaseResponse):
|
||||
return HttpResponse(data.json(exclude_none=True), mimetype="application/json")
|
||||
|
||||
if isinstance(data, list) and len(data) > 0 and isinstance(data[0], BaseResponse):
|
||||
decoded = [json.loads(x.json(exclude_none=True)) for x in data]
|
||||
return HttpResponse(json.dumps(decoded), mimetype="application/json")
|
||||
|
||||
if isinstance(data, ModelMixin):
|
||||
return HttpResponse(
|
||||
data.json(exclude_none=True, exclude=data.export_exclude()),
|
||||
mimetype="application/json",
|
||||
)
|
||||
|
||||
decoded = [
|
||||
x.raw(exclude_none=True, exclude=x.export_exclude())
|
||||
if isinstance(x, ModelMixin)
|
||||
else x
|
||||
for x in data
|
||||
]
|
||||
return HttpResponse(
|
||||
json.dumps(decoded),
|
||||
mimetype="application/json",
|
||||
)
|
||||
|
||||
|
||||
def not_ok(
    error: Error, *, status_code: int = 400, context: Union[str, UUID]
) -> HttpResponse:
    if 400 <= status_code <= 599:
        logging.error("request error - %s: %s", str(context), error.json())

        return HttpResponse(
            error.json(), status_code=status_code, mimetype="application/json"
        )
    else:
        raise Exception(
            "status code %s is not in the expected range [400; 599]" % status_code
        )
|
||||
|
||||
|
||||
def redirect(url: str) -> HttpResponse:
|
||||
return HttpResponse(status_code=302, headers={"Location": url})
|
||||
|
||||
|
||||
def convert_error(err: ValidationError) -> Error:
|
||||
errors = []
|
||||
for error in err.errors():
|
||||
if isinstance(error["loc"], tuple):
|
||||
name = ".".join([str(x) for x in error["loc"]])
|
||||
else:
|
||||
name = str(error["loc"])
|
||||
errors.append("%s: %s" % (name, error["msg"]))
|
||||
return Error(code=ErrorCode.INVALID_REQUEST, errors=errors)
|
||||
|
||||
|
||||
# TODO: loosen restrictions here during dev. We should be specific
|
||||
# about only parsing things that are of a "Request" type, but there are
|
||||
# a handful of types that need work in order to enforce that.
|
||||
#
|
||||
# These can be easily found by swapping the following comment and running
|
||||
# mypy.
|
||||
#
|
||||
# A = TypeVar("A", bound="BaseRequest")
|
||||
A = TypeVar("A", bound="BaseModel")
|
||||
|
||||
|
||||
def parse_request(cls: Type[A], req: HttpRequest) -> Union[A, Error]:
|
||||
try:
|
||||
return cls.parse_obj(req.get_json())
|
||||
except ValidationError as err:
|
||||
return convert_error(err)
|
||||
|
||||
|
||||
def parse_uri(cls: Type[A], req: HttpRequest) -> Union[A, Error]:
|
||||
data = {}
|
||||
for key in req.params:
|
||||
data[key] = req.params[key]
|
||||
|
||||
try:
|
||||
return cls.parse_obj(data)
|
||||
except ValidationError as err:
|
||||
return convert_error(err)
|
||||
|
||||
|
||||
class RequestException(Exception):
|
||||
def __init__(self, error: Error):
|
||||
self.error = error
|
||||
message = "error %s" % error
|
||||
super().__init__(message)
|
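A minimal handler sketch (PingRequest/PingResponse are hypothetical stand-ins, not part of this module) showing how parse_request, ok and not_ok are meant to compose:

class PingRequest(BaseModel):
    ping_id: UUID

class PingResponse(BaseResponse):
    ping_id: UUID

def main(req: HttpRequest) -> HttpResponse:
    request = parse_request(PingRequest, req)
    if isinstance(request, Error):
        return not_ok(request, context="ping")
    return ok(PingResponse(ping_id=request.ping_id))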
@ -1,109 +0,0 @@
|
||||
from typing import Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from onefuzztypes.models import ApiAccessRule
|
||||
|
||||
|
||||
class RuleConflictError(Exception):
|
||||
def __init__(self, message: str) -> None:
|
||||
super(RuleConflictError, self).__init__(message)
|
||||
self.message = message
|
||||
|
||||
|
||||
class RequestAccess:
    """
    Stores the rules associated with the request paths
    """

    class Rules:
        allowed_groups_ids: List[UUID]

        def __init__(self, allowed_groups_ids: Optional[List[UUID]] = None) -> None:
            # avoid sharing a mutable default argument across instances
            self.allowed_groups_ids = allowed_groups_ids or []

    class Node:
        # http method -> rules
        rules: Dict[str, "RequestAccess.Rules"]
        # path -> node
        children: Dict[str, "RequestAccess.Node"]

        def __init__(self) -> None:
            self.rules = {}
            self.children = {}
|
||||
|
||||
root: Node
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.root = RequestAccess.Node()
|
||||
|
||||
def __add_url__(self, methods: List[str], path: str, rules: Rules) -> None:
|
||||
methods = list(map(lambda m: m.upper(), methods))
|
||||
|
||||
segments = [s for s in path.split("/") if s != ""]
|
||||
if len(segments) == 0:
|
||||
return
|
||||
|
||||
current_node = self.root
|
||||
current_segment_index = 0
|
||||
|
||||
while current_segment_index < len(segments):
|
||||
current_segment = segments[current_segment_index]
|
||||
if current_segment in current_node.children:
|
||||
current_node = current_node.children[current_segment]
|
||||
current_segment_index = current_segment_index + 1
|
||||
else:
|
||||
break
|
||||
# we found a node matching this exact path
|
||||
# This means that there is an existing rule causing a conflict
|
||||
if current_segment_index == len(segments):
|
||||
for method in methods:
|
||||
if method in current_node.rules:
|
||||
raise RuleConflictError(f"Conflicting rules on {method} {path}")
|
||||
|
||||
while current_segment_index < len(segments):
|
||||
current_segment = segments[current_segment_index]
|
||||
current_node.children[current_segment] = RequestAccess.Node()
|
||||
current_node = current_node.children[current_segment]
|
||||
current_segment_index = current_segment_index + 1
|
||||
|
||||
for method in methods:
|
||||
current_node.rules[method] = rules
|
||||
|
||||
def get_matching_rules(self, method: str, path: str) -> Optional[Rules]:
|
||||
method = method.upper()
|
||||
segments = [s for s in path.split("/") if s != ""]
|
||||
current_node = self.root
|
||||
current_rule = None
|
||||
|
||||
if method in current_node.rules:
|
||||
current_rule = current_node.rules[method]
|
||||
|
||||
current_segment_index = 0
|
||||
|
||||
while current_segment_index < len(segments):
|
||||
current_segment = segments[current_segment_index]
|
||||
if current_segment in current_node.children:
|
||||
current_node = current_node.children[current_segment]
|
||||
elif "*" in current_node.children:
|
||||
current_node = current_node.children["*"]
|
||||
else:
|
||||
break
|
||||
|
||||
if method in current_node.rules:
|
||||
current_rule = current_node.rules[method]
|
||||
current_segment_index = current_segment_index + 1
|
||||
return current_rule
|
||||
|
||||
@classmethod
|
||||
def build(cls, rules: Dict[str, ApiAccessRule]) -> "RequestAccess":
|
||||
request_access = RequestAccess()
|
||||
for endpoint in rules:
|
||||
rule = rules[endpoint]
|
||||
request_access.__add_url__(
|
||||
rule.methods,
|
||||
endpoint,
|
||||
RequestAccess.Rules(allowed_groups_ids=rule.allowed_groups),
|
||||
)
|
||||
|
||||
return request_access
|
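An illustrative usage sketch (the group id and endpoint are hypothetical; production code builds the trie from ApiAccessRule entries via RequestAccess.build):

from uuid import uuid4

admin_group = uuid4()  # hypothetical AAD group id
access = RequestAccess()
access.__add_url__(["POST"], "/api/tasks/*", RequestAccess.Rules([admin_group]))

matched = access.get_matching_rules("POST", "/api/tasks/some-task-id")
assert matched is not None and matched.allowed_groups_ids == [admin_group]
assert access.get_matching_rules("GET", "/api/tasks/some-task-id") is None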
@ -1,87 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import os
|
||||
from typing import Tuple, Type, TypeVar
|
||||
from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
|
||||
from azure.keyvault.secrets import KeyVaultSecret
|
||||
from onefuzztypes.models import SecretAddress, SecretData
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .azure.creds import get_keyvault_client
|
||||
|
||||
A = TypeVar("A", bound=BaseModel)
|
||||
|
||||
|
||||
def save_to_keyvault(secret_data: SecretData) -> SecretData:
|
||||
if isinstance(secret_data.secret, SecretAddress):
|
||||
return secret_data
|
||||
|
||||
secret_name = str(uuid4())
|
||||
if isinstance(secret_data.secret, str):
|
||||
secret_value = secret_data.secret
|
||||
elif isinstance(secret_data.secret, BaseModel):
|
||||
secret_value = secret_data.secret.json()
|
||||
else:
|
||||
raise Exception("invalid secret data")
|
||||
|
||||
kv = store_in_keyvault(get_keyvault_address(), secret_name, secret_value)
|
||||
secret_data.secret = SecretAddress(url=kv.id)
|
||||
return secret_data
|
||||
|
||||
|
||||
def get_secret_string_value(self: SecretData[str]) -> str:
|
||||
if isinstance(self.secret, SecretAddress):
|
||||
secret = get_secret(self.secret.url)
|
||||
return secret.value
|
||||
else:
|
||||
return self.secret
|
||||
|
||||
|
||||
def get_keyvault_address() -> str:
|
||||
# https://docs.microsoft.com/en-us/azure/key-vault/general/about-keys-secrets-certificates#vault-name-and-object-name
|
||||
keyvault_name = os.environ["ONEFUZZ_KEYVAULT"]
|
||||
return f"https://{keyvault_name}.vault.azure.net"
|
||||
|
||||
|
||||
def store_in_keyvault(
|
||||
keyvault_url: str, secret_name: str, secret_value: str
|
||||
) -> KeyVaultSecret:
|
||||
keyvault_client = get_keyvault_client(keyvault_url)
|
||||
kvs: KeyVaultSecret = keyvault_client.set_secret(secret_name, secret_value)
|
||||
return kvs
|
||||
|
||||
|
||||
def parse_secret_url(secret_url: str) -> Tuple[str, str]:
|
||||
# format: https://{vault-name}.vault.azure.net/secrets/{secret-name}/{version}
|
||||
u = urlparse(secret_url)
|
||||
vault_url = f"{u.scheme}://{u.netloc}"
|
||||
secret_name = u.path.split("/")[2]
|
||||
return (vault_url, secret_name)
|
||||
|
||||
|
||||
def get_secret(secret_url: str) -> KeyVaultSecret:
|
||||
(vault_url, secret_name) = parse_secret_url(secret_url)
|
||||
keyvault_client = get_keyvault_client(vault_url)
|
||||
return keyvault_client.get_secret(secret_name)
|
||||
|
||||
|
||||
def get_secret_obj(secret_url: str, model: Type[A]) -> A:
|
||||
secret = get_secret(secret_url)
|
||||
return model.parse_raw(secret.value)
|
||||
|
||||
|
||||
def delete_secret(secret_url: str) -> None:
|
||||
(vault_url, secret_name) = parse_secret_url(secret_url)
|
||||
keyvault_client = get_keyvault_client(vault_url)
|
||||
keyvault_client.begin_delete_secret(secret_name)
|
||||
|
||||
|
||||
def delete_remote_secret_data(data: SecretData) -> None:
|
||||
if isinstance(data.secret, SecretAddress):
|
||||
delete_secret(data.secret.url)
|
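For illustration, parse_secret_url splits a Key Vault secret URL into the vault endpoint and secret name consumed by the client helpers above (the vault name below is an example, not a real deployment):

url = "https://example-vault.vault.azure.net/secrets/my-secret/abc123"
vault_url, secret_name = parse_secret_url(url)
assert vault_url == "https://example-vault.vault.azure.net"
assert secret_name == "my-secret"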
@ -1,52 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import datetime
|
||||
from typing import List, Optional, Tuple
|
||||
from uuid import UUID
|
||||
|
||||
from onefuzztypes.models import TaskEvent as BASE_TASK_EVENT
|
||||
from onefuzztypes.models import TaskEventSummary, WorkerEvent
|
||||
|
||||
from .orm import ORMMixin
|
||||
|
||||
|
||||
class TaskEvent(BASE_TASK_EVENT, ORMMixin):
|
||||
@classmethod
|
||||
def get_summary(cls, task_id: UUID) -> List[TaskEventSummary]:
|
||||
events = cls.search(query={"task_id": [task_id]})
|
||||
# handle None case of Optional[e.timestamp], which shouldn't happen
|
||||
events.sort(key=lambda e: e.timestamp or datetime.datetime.max)
|
||||
|
||||
return [
|
||||
TaskEventSummary(
|
||||
timestamp=e.timestamp,
|
||||
event_data=get_event_data(e.event_data),
|
||||
event_type=get_event_type(e.event_data),
|
||||
)
|
||||
for e in events
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def key_fields(cls) -> Tuple[str, Optional[str]]:
|
||||
return ("task_id", None)
|
||||
|
||||
|
||||
def get_event_data(event: WorkerEvent) -> str:
    if event.done:
        return "exit status: %s" % event.done.exit_status
    elif event.running:
        return ""
    else:
        return "Unrecognized event: %s" % event


def get_event_type(event: WorkerEvent) -> str:
    if event.done:
        return type(event.done).__name__
    elif event.running:
        return type(event.running).__name__
    else:
        return "Unrecognized event: %s" % event
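For illustration (standalone, not from the original file), the None-safe ordering used in get_summary: events missing a timestamp sort last instead of raising on comparison with datetime values.

import datetime

timestamps = [datetime.datetime(2022, 1, 2), None, datetime.datetime(2022, 1, 1)]
ordered = sorted(timestamps, key=lambda t: t or datetime.datetime.max)
assert ordered[0] == datetime.datetime(2022, 1, 1)
assert ordered[-1] is None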
@ -1,466 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from onefuzztypes.enums import Compare, ContainerPermission, ContainerType, TaskFeature
|
||||
from onefuzztypes.models import Job, Task, TaskConfig, TaskDefinition, TaskUnitConfig
|
||||
from onefuzztypes.primitives import Container
|
||||
|
||||
from ..azure.containers import (
|
||||
add_container_sas_url,
|
||||
blob_exists,
|
||||
container_exists,
|
||||
get_container_sas_url,
|
||||
)
|
||||
from ..azure.creds import get_instance_id
|
||||
from ..azure.queue import get_queue_sas
|
||||
from ..azure.storage import StorageType
|
||||
from ..workers.pools import Pool
|
||||
from .defs import TASK_DEFINITIONS
|
||||
|
||||
LOGGER = logging.getLogger("onefuzz")
|
||||
|
||||
|
||||
def get_input_container_queues(config: TaskConfig) -> Optional[List[str]]: # tasks.Task
|
||||
|
||||
if config.task.type not in TASK_DEFINITIONS:
|
||||
raise TaskConfigError("unsupported task type: %s" % config.task.type.name)
|
||||
|
||||
container_type = TASK_DEFINITIONS[config.task.type].monitor_queue
|
||||
if container_type:
|
||||
return [x.name for x in config.containers if x.type == container_type]
|
||||
return None
|
||||
|
||||
|
||||
def check_val(compare: Compare, expected: int, actual: int) -> bool:
    if compare == Compare.Equal:
        return expected == actual

    if compare == Compare.AtLeast:
        return expected <= actual

    if compare == Compare.AtMost:
        return expected >= actual

    raise NotImplementedError
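A hedged sketch of check_val semantics: "expected" is the bound declared in a task or container definition and "actual" is what the submitted config supplies.

assert check_val(Compare.Equal, 1, 1)
assert check_val(Compare.AtLeast, 1, 3)       # at least 1 required, 3 given
assert not check_val(Compare.AtMost, 1, 2)    # at most 1 allowed, 2 given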
|
||||
|
||||
|
||||
def check_container(
|
||||
compare: Compare,
|
||||
expected: int,
|
||||
container_type: ContainerType,
|
||||
containers: Dict[ContainerType, List[Container]],
|
||||
) -> None:
|
||||
actual = len(containers.get(container_type, []))
|
||||
if not check_val(compare, expected, actual):
|
||||
raise TaskConfigError(
|
||||
"container type %s: expected %s %d, got %d"
|
||||
% (container_type.name, compare.name, expected, actual)
|
||||
)
|
||||
|
||||
|
||||
def check_containers(definition: TaskDefinition, config: TaskConfig) -> None:
|
||||
checked = set()
|
||||
|
||||
containers: Dict[ContainerType, List[Container]] = {}
|
||||
for container in config.containers:
|
||||
if container.name not in checked:
|
||||
if not container_exists(container.name, StorageType.corpus):
|
||||
raise TaskConfigError("missing container: %s" % container.name)
|
||||
checked.add(container.name)
|
||||
|
||||
if container.type not in containers:
|
||||
containers[container.type] = []
|
||||
containers[container.type].append(container.name)
|
||||
|
||||
for container_def in definition.containers:
|
||||
check_container(
|
||||
container_def.compare, container_def.value, container_def.type, containers
|
||||
)
|
||||
|
||||
for container_type in containers:
|
||||
if container_type not in [x.type for x in definition.containers]:
|
||||
raise TaskConfigError(
|
||||
"unsupported container type for this task: %s" % container_type.name
|
||||
)
|
||||
|
||||
if definition.monitor_queue:
|
||||
if definition.monitor_queue not in [x.type for x in definition.containers]:
|
||||
raise TaskConfigError(
|
||||
"unable to monitor container type as it is not used by this task: %s"
|
||||
% definition.monitor_queue.name
|
||||
)
|
||||
|
||||
|
||||
def check_target_exe(config: TaskConfig, definition: TaskDefinition) -> None:
|
||||
if config.task.target_exe is None:
|
||||
if TaskFeature.target_exe in definition.features:
|
||||
raise TaskConfigError("missing target_exe")
|
||||
|
||||
if TaskFeature.target_exe_optional in definition.features:
|
||||
return
|
||||
|
||||
return
|
||||
|
||||
# User-submitted paths must be relative to the setup directory that contains them.
|
||||
# They also must be normalized, and exclude special filesystem path elements.
|
||||
#
|
||||
# For example, accessing the blob store path "./foo" generates an exception, but
|
||||
# "foo" and "foo/bar" do not.
|
||||
if not is_valid_blob_name(config.task.target_exe):
|
||||
raise TaskConfigError("target_exe must be a canonicalized relative path")
|
||||
|
||||
container = [x for x in config.containers if x.type == ContainerType.setup][0]
|
||||
if not blob_exists(container.name, config.task.target_exe, StorageType.corpus):
|
||||
err = "target_exe `%s` does not exist in the setup container `%s`" % (
|
||||
config.task.target_exe,
|
||||
container.name,
|
||||
)
|
||||
LOGGER.warning(err)
|
||||
|
||||
|
||||
# Azure Blob Storage uses a flat scheme, and has no true directory hierarchy. Forward
|
||||
# slashes are used to delimit a _virtual_ directory structure.
|
||||
def is_valid_blob_name(blob_name: str) -> bool:
|
||||
# https://docs.microsoft.com/en-us/rest/api/storageservices/naming-and-referencing-containers--blobs--and-metadata#blob-names
|
||||
MIN_LENGTH = 1
|
||||
MAX_LENGTH = 1024 # Inclusive
|
||||
MAX_PATH_SEGMENTS = 254
|
||||
|
||||
length = len(blob_name)
|
||||
|
||||
# No leading/trailing whitespace.
|
||||
if blob_name != blob_name.strip():
|
||||
return False
|
||||
|
||||
if length < MIN_LENGTH:
|
||||
return False
|
||||
|
||||
if length > MAX_LENGTH:
|
||||
return False
|
||||
|
||||
path = pathlib.PurePosixPath(blob_name)
|
||||
|
||||
if len(path.parts) > MAX_PATH_SEGMENTS:
|
||||
return False
|
||||
|
||||
# No path segment should end with a dot (`.`).
|
||||
for part in path.parts:
|
||||
if part.endswith("."):
|
||||
return False
|
||||
|
||||
# Reject absolute paths to avoid confusion.
|
||||
if path.is_absolute():
|
||||
return False
|
||||
|
||||
# Reject paths with special relative filesystem entries.
|
||||
if "." in path.parts:
|
||||
return False
|
||||
|
||||
if ".." in path.parts:
|
||||
return False
|
||||
|
||||
# Will not have a leading `.`, even if `blob_name` does.
|
||||
normalized = path.as_posix()
|
||||
|
||||
return blob_name == normalized
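For illustration, the expected behaviour of is_valid_blob_name on a few sample inputs (these inputs are illustrative, not from the original tests):

assert is_valid_blob_name("foo")
assert is_valid_blob_name("foo/bar.exe")
assert not is_valid_blob_name("./foo")       # not canonical: normalizes to "foo"
assert not is_valid_blob_name("../escape")   # ".." path segments are rejected
assert not is_valid_blob_name("/abs/path")   # absolute paths are rejected
assert not is_valid_blob_name("trailing.")   # segments must not end with "."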
|
||||
|
||||
|
||||
def target_uses_input(config: TaskConfig) -> bool:
|
||||
if config.task.target_options is not None:
|
||||
for option in config.task.target_options:
|
||||
if "{input}" in option:
|
||||
return True
|
||||
if config.task.target_env is not None:
|
||||
for value in config.task.target_env.values():
|
||||
if "{input}" in value:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def check_config(config: TaskConfig) -> None:
|
||||
if config.task.type not in TASK_DEFINITIONS:
|
||||
raise TaskConfigError("unsupported task type: %s" % config.task.type.name)
|
||||
|
||||
if config.vm is not None and config.pool is not None:
|
||||
raise TaskConfigError("either the vm or pool must be specified, but not both")
|
||||
|
||||
definition = TASK_DEFINITIONS[config.task.type]
|
||||
|
||||
check_containers(definition, config)
|
||||
|
||||
if (
|
||||
TaskFeature.supervisor_exe in definition.features
|
||||
and not config.task.supervisor_exe
|
||||
):
|
||||
err = "missing supervisor_exe"
|
||||
LOGGER.error(err)
|
||||
raise TaskConfigError("missing supervisor_exe")
|
||||
|
||||
if (
|
||||
TaskFeature.target_must_use_input in definition.features
|
||||
and not target_uses_input(config)
|
||||
):
|
||||
raise TaskConfigError("{input} must be used in target_env or target_options")
|
||||
|
||||
if config.vm:
|
||||
err = "specifying task config vm is no longer supported"
|
||||
raise TaskConfigError(err)
|
||||
|
||||
if not config.pool:
|
||||
raise TaskConfigError("pool must be specified")
|
||||
|
||||
if not check_val(definition.vm.compare, definition.vm.value, config.pool.count):
|
||||
err = "invalid vm count: expected %s %d, got %s" % (
|
||||
definition.vm.compare,
|
||||
definition.vm.value,
|
||||
config.pool.count,
|
||||
)
|
||||
LOGGER.error(err)
|
||||
raise TaskConfigError(err)
|
||||
|
||||
pool = Pool.get_by_name(config.pool.pool_name)
|
||||
if not isinstance(pool, Pool):
|
||||
raise TaskConfigError(f"invalid pool: {config.pool.pool_name}")
|
||||
|
||||
check_target_exe(config, definition)
|
||||
|
||||
if TaskFeature.generator_exe in definition.features:
|
||||
container = [x for x in config.containers if x.type == ContainerType.tools][0]
|
||||
if not config.task.generator_exe:
|
||||
raise TaskConfigError("generator_exe is not defined")
|
||||
|
||||
tools_paths = ["{tools_dir}/", "{tools_dir}\\"]
|
||||
for tool_path in tools_paths:
|
||||
if config.task.generator_exe.startswith(tool_path):
|
||||
generator = config.task.generator_exe.replace(tool_path, "")
|
||||
if not blob_exists(container.name, generator, StorageType.corpus):
|
||||
err = (
|
||||
"generator_exe `%s` does not exist in the tools container `%s`"
|
||||
% (
|
||||
config.task.generator_exe,
|
||||
container.name,
|
||||
)
|
||||
)
|
||||
LOGGER.error(err)
|
||||
raise TaskConfigError(err)
|
||||
|
||||
if TaskFeature.stats_file in definition.features:
|
||||
if config.task.stats_file is not None and config.task.stats_format is None:
|
||||
err = "using a stats_file requires a stats_format"
|
||||
LOGGER.error(err)
|
||||
raise TaskConfigError(err)
|
||||
|
||||
|
||||
def build_task_config(job: Job, task: Task) -> TaskUnitConfig:
|
||||
job_id = job.job_id
|
||||
task_id = task.task_id
|
||||
task_config = task.config
|
||||
|
||||
if task_config.task.type not in TASK_DEFINITIONS:
|
||||
raise TaskConfigError("unsupported task type: %s" % task_config.task.type.name)
|
||||
|
||||
definition = TASK_DEFINITIONS[task_config.task.type]
|
||||
|
||||
config = TaskUnitConfig(
|
||||
job_id=job_id,
|
||||
task_id=task_id,
|
||||
task_type=task_config.task.type,
|
||||
instance_telemetry_key=os.environ.get("APPINSIGHTS_INSTRUMENTATIONKEY"),
|
||||
microsoft_telemetry_key=os.environ.get("ONEFUZZ_TELEMETRY"),
|
||||
heartbeat_queue=get_queue_sas(
|
||||
"task-heartbeat",
|
||||
StorageType.config,
|
||||
add=True,
|
||||
),
|
||||
instance_id=get_instance_id(),
|
||||
)
|
||||
|
||||
if job.config.logs:
|
||||
config.logs = add_container_sas_url(job.config.logs)
|
||||
else:
|
||||
LOGGER.warning("Missing log container: job_id %s, task_id %s", job_id, task_id)
|
||||
|
||||
if definition.monitor_queue:
|
||||
config.input_queue = get_queue_sas(
|
||||
task_id,
|
||||
StorageType.corpus,
|
||||
add=True,
|
||||
read=True,
|
||||
update=True,
|
||||
process=True,
|
||||
)
|
||||
|
||||
for container_def in definition.containers:
|
||||
if container_def.type == ContainerType.setup:
|
||||
continue
|
||||
|
||||
containers = []
|
||||
for (i, container) in enumerate(task_config.containers):
|
||||
if container.type != container_def.type:
|
||||
continue
|
||||
|
||||
containers.append(
|
||||
{
|
||||
"path": "_".join(["task", container_def.type.name, str(i)]),
|
||||
"url": get_container_sas_url(
|
||||
container.name,
|
||||
StorageType.corpus,
|
||||
read=ContainerPermission.Read in container_def.permissions,
|
||||
write=ContainerPermission.Write in container_def.permissions,
|
||||
delete=ContainerPermission.Delete in container_def.permissions,
|
||||
list_=ContainerPermission.List in container_def.permissions,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
if not containers:
|
||||
continue
|
||||
|
||||
if (
|
||||
container_def.compare in [Compare.Equal, Compare.AtMost]
|
||||
and container_def.value == 1
|
||||
):
|
||||
setattr(config, container_def.type.name, containers[0])
|
||||
else:
|
||||
setattr(config, container_def.type.name, containers)
|
||||
|
||||
EMPTY_DICT: Dict[str, str] = {}
|
||||
EMPTY_LIST: List[str] = []
|
||||
|
||||
if TaskFeature.supervisor_exe in definition.features:
|
||||
config.supervisor_exe = task_config.task.supervisor_exe
|
||||
|
||||
if TaskFeature.supervisor_env in definition.features:
|
||||
config.supervisor_env = task_config.task.supervisor_env or EMPTY_DICT
|
||||
|
||||
if TaskFeature.supervisor_options in definition.features:
|
||||
config.supervisor_options = task_config.task.supervisor_options or EMPTY_LIST
|
||||
|
||||
if TaskFeature.supervisor_input_marker in definition.features:
|
||||
config.supervisor_input_marker = task_config.task.supervisor_input_marker
|
||||
|
||||
if TaskFeature.target_exe in definition.features:
|
||||
config.target_exe = task_config.task.target_exe
|
||||
|
||||
if (
|
||||
TaskFeature.target_exe_optional in definition.features
|
||||
and task_config.task.target_exe
|
||||
):
|
||||
config.target_exe = task_config.task.target_exe
|
||||
|
||||
if TaskFeature.target_env in definition.features:
|
||||
config.target_env = task_config.task.target_env or EMPTY_DICT
|
||||
|
||||
if TaskFeature.target_options in definition.features:
|
||||
config.target_options = task_config.task.target_options or EMPTY_LIST
|
||||
|
||||
if TaskFeature.target_options_merge in definition.features:
|
||||
config.target_options_merge = task_config.task.target_options_merge or False
|
||||
|
||||
if TaskFeature.target_workers in definition.features:
|
||||
config.target_workers = task_config.task.target_workers
|
||||
|
||||
if TaskFeature.rename_output in definition.features:
|
||||
config.rename_output = task_config.task.rename_output or False
|
||||
|
||||
if TaskFeature.generator_exe in definition.features:
|
||||
config.generator_exe = task_config.task.generator_exe
|
||||
|
||||
if TaskFeature.generator_env in definition.features:
|
||||
config.generator_env = task_config.task.generator_env or EMPTY_DICT
|
||||
|
||||
if TaskFeature.generator_options in definition.features:
|
||||
config.generator_options = task_config.task.generator_options or EMPTY_LIST
|
||||
|
||||
if (
|
||||
TaskFeature.wait_for_files in definition.features
|
||||
and task_config.task.wait_for_files
|
||||
):
|
||||
config.wait_for_files = task_config.task.wait_for_files.name
|
||||
|
||||
if TaskFeature.analyzer_exe in definition.features:
|
||||
config.analyzer_exe = task_config.task.analyzer_exe
|
||||
|
||||
if TaskFeature.analyzer_options in definition.features:
|
||||
config.analyzer_options = task_config.task.analyzer_options or EMPTY_LIST
|
||||
|
||||
if TaskFeature.analyzer_env in definition.features:
|
||||
config.analyzer_env = task_config.task.analyzer_env or EMPTY_DICT
|
||||
|
||||
if TaskFeature.stats_file in definition.features:
|
||||
config.stats_file = task_config.task.stats_file
|
||||
config.stats_format = task_config.task.stats_format
|
||||
|
||||
if TaskFeature.target_timeout in definition.features:
|
||||
config.target_timeout = task_config.task.target_timeout
|
||||
|
||||
if TaskFeature.check_asan_log in definition.features:
|
||||
config.check_asan_log = task_config.task.check_asan_log
|
||||
|
||||
if TaskFeature.check_debugger in definition.features:
|
||||
config.check_debugger = task_config.task.check_debugger
|
||||
|
||||
if TaskFeature.check_retry_count in definition.features:
|
||||
config.check_retry_count = task_config.task.check_retry_count or 0
|
||||
|
||||
if TaskFeature.ensemble_sync_delay in definition.features:
|
||||
config.ensemble_sync_delay = task_config.task.ensemble_sync_delay
|
||||
|
||||
if TaskFeature.check_fuzzer_help in definition.features:
|
||||
config.check_fuzzer_help = (
|
||||
task_config.task.check_fuzzer_help
|
||||
if task_config.task.check_fuzzer_help is not None
|
||||
else True
|
||||
)
|
||||
|
||||
if TaskFeature.report_list in definition.features:
|
||||
config.report_list = task_config.task.report_list
|
||||
|
||||
if TaskFeature.minimized_stack_depth in definition.features:
|
||||
config.minimized_stack_depth = task_config.task.minimized_stack_depth
|
||||
|
||||
if TaskFeature.expect_crash_on_failure in definition.features:
|
||||
config.expect_crash_on_failure = (
|
||||
task_config.task.expect_crash_on_failure
|
||||
if task_config.task.expect_crash_on_failure is not None
|
||||
else True
|
||||
)
|
||||
|
||||
if TaskFeature.coverage_filter in definition.features:
|
||||
coverage_filter = task_config.task.coverage_filter
|
||||
|
||||
if coverage_filter is not None:
|
||||
config.coverage_filter = coverage_filter
|
||||
|
||||
if TaskFeature.target_assembly in definition.features:
|
||||
config.target_assembly = task_config.task.target_assembly
|
||||
|
||||
if TaskFeature.target_class in definition.features:
|
||||
config.target_class = task_config.task.target_class
|
||||
|
||||
if TaskFeature.target_method in definition.features:
|
||||
config.target_method = task_config.task.target_method
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def get_setup_container(config: TaskConfig) -> Container:
|
||||
for container in config.containers:
|
||||
if container.type == ContainerType.setup:
|
||||
return container.name
|
||||
|
||||
raise TaskConfigError(
|
||||
"task missing setup container: task_type = %s" % config.task.type
|
||||
)
|
||||
|
||||
|
||||
class TaskConfigError(Exception):
|
||||
pass
|
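A hedged sketch of the lookup pattern used by check_config and build_task_config against the TASK_DEFINITIONS table defined in the next file (TaskType comes from onefuzztypes.enums):

from onefuzztypes.enums import TaskType

definition = TASK_DEFINITIONS[TaskType.libfuzzer_fuzz]
assert any(c.type == ContainerType.setup for c in definition.containers)
print(definition.monitor_queue)  # None: libfuzzer_fuzz polls no input queue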
@ -1,707 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from onefuzztypes.enums import (
|
||||
Compare,
|
||||
ContainerPermission,
|
||||
ContainerType,
|
||||
TaskFeature,
|
||||
TaskType,
|
||||
)
|
||||
from onefuzztypes.models import ContainerDefinition, TaskDefinition, VmDefinition
|
||||
|
||||
# all tasks are required to have a 'setup' container
|
||||
TASK_DEFINITIONS = {
|
||||
TaskType.coverage: TaskDefinition(
|
||||
features=[
|
||||
TaskFeature.target_exe,
|
||||
TaskFeature.target_env,
|
||||
TaskFeature.target_options,
|
||||
TaskFeature.target_timeout,
|
||||
TaskFeature.coverage_filter,
|
||||
TaskFeature.target_must_use_input,
|
||||
],
|
||||
vm=VmDefinition(compare=Compare.Equal, value=1),
|
||||
containers=[
|
||||
ContainerDefinition(
|
||||
type=ContainerType.setup,
|
||||
compare=Compare.Equal,
|
||||
value=1,
|
||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
||||
),
|
||||
ContainerDefinition(
|
||||
type=ContainerType.readonly_inputs,
|
||||
compare=Compare.AtLeast,
|
||||
value=1,
|
||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
||||
),
|
||||
ContainerDefinition(
|
||||
type=ContainerType.coverage,
|
||||
compare=Compare.Equal,
|
||||
value=1,
|
||||
permissions=[
|
||||
ContainerPermission.List,
|
||||
ContainerPermission.Read,
|
||||
ContainerPermission.Write,
|
||||
],
|
||||
),
|
||||
],
|
||||
monitor_queue=ContainerType.readonly_inputs,
|
||||
),
|
||||
TaskType.dotnet_coverage: TaskDefinition(
|
||||
features=[
|
||||
TaskFeature.target_exe,
|
||||
TaskFeature.target_env,
|
||||
TaskFeature.target_options,
|
||||
TaskFeature.target_timeout,
|
||||
TaskFeature.coverage_filter,
|
||||
TaskFeature.target_must_use_input,
|
||||
],
|
||||
vm=VmDefinition(compare=Compare.Equal, value=1),
|
||||
containers=[
|
||||
ContainerDefinition(
|
||||
type=ContainerType.setup,
|
||||
compare=Compare.Equal,
|
||||
value=1,
|
||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
||||
),
|
||||
ContainerDefinition(
|
||||
type=ContainerType.readonly_inputs,
|
||||
compare=Compare.AtLeast,
|
||||
value=1,
|
||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
||||
),
|
||||
ContainerDefinition(
|
||||
type=ContainerType.coverage,
|
||||
compare=Compare.Equal,
|
||||
value=1,
|
||||
permissions=[
|
||||
ContainerPermission.List,
|
||||
ContainerPermission.Read,
|
||||
ContainerPermission.Write,
|
||||
],
|
||||
),
|
||||
ContainerDefinition(
|
||||
type=ContainerType.tools,
|
||||
compare=Compare.Equal,
|
||||
value=1,
|
||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
||||
),
|
||||
],
|
||||
monitor_queue=ContainerType.readonly_inputs,
|
||||
),
|
||||
TaskType.dotnet_crash_report: TaskDefinition(
|
||||
features=[
|
||||
TaskFeature.target_exe,
|
||||
TaskFeature.target_env,
|
||||
TaskFeature.target_options,
|
||||
TaskFeature.target_timeout,
|
||||
TaskFeature.check_asan_log,
|
||||
TaskFeature.check_debugger,
|
||||
TaskFeature.check_retry_count,
|
||||
TaskFeature.minimized_stack_depth,
|
||||
],
|
||||
vm=VmDefinition(compare=Compare.AtLeast, value=1),
|
||||
containers=[
|
||||
ContainerDefinition(
|
||||
type=ContainerType.setup,
|
||||
compare=Compare.Equal,
|
||||
value=1,
|
||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
||||
),
|
||||
ContainerDefinition(
|
||||
type=ContainerType.crashes,
|
||||
compare=Compare.Equal,
|
||||
value=1,
|
||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
||||
),
|
||||
ContainerDefinition(
|
||||
type=ContainerType.reports,
|
||||
compare=Compare.AtMost,
|
||||
value=1,
|
||||
permissions=[ContainerPermission.Write],
|
||||
),
|
||||
ContainerDefinition(
|
||||
type=ContainerType.unique_reports,
|
||||
compare=Compare.AtMost,
|
||||
value=1,
|
||||
permissions=[ContainerPermission.Write],
|
||||
),
|
||||
ContainerDefinition(
|
||||
type=ContainerType.no_repro,
|
||||
compare=Compare.AtMost,
|
||||
value=1,
|
||||
permissions=[ContainerPermission.Write],
|
||||
),
|
||||
ContainerDefinition(
|
||||
type=ContainerType.tools,
|
||||
compare=Compare.Equal,
|
||||
value=1,
|
||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
||||
),
|
||||
],
|
||||
monitor_queue=ContainerType.crashes,
|
||||
),
|
||||
TaskType.libfuzzer_dotnet_fuzz: TaskDefinition(
|
||||
features=[
|
||||
TaskFeature.target_exe,
|
||||
TaskFeature.target_env,
|
||||
TaskFeature.target_options,
|
||||
TaskFeature.target_workers,
|
||||
TaskFeature.ensemble_sync_delay,
|
||||
TaskFeature.check_fuzzer_help,
|
||||
TaskFeature.expect_crash_on_failure,
|
||||
TaskFeature.target_assembly,
|
||||
TaskFeature.target_class,
|
||||
TaskFeature.target_method,
|
||||
],
|
||||
vm=VmDefinition(compare=Compare.AtLeast, value=1),
|
||||
containers=[
|
||||
ContainerDefinition(
|
||||
type=ContainerType.setup,
|
||||
compare=Compare.Equal,
|
||||
value=1,
|
||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
||||
),
|
||||
ContainerDefinition(
|
||||
type=ContainerType.crashes,
|
||||
compare=Compare.Equal,
|
||||
value=1,
|
||||
permissions=[ContainerPermission.Write],
|
||||
),
|
||||
ContainerDefinition(
|
||||
type=ContainerType.inputs,
|
||||
compare=Compare.Equal,
|
||||
value=1,
|
||||
permissions=[
|
||||
ContainerPermission.Write,
|
||||
ContainerPermission.Read,
|
||||
ContainerPermission.List,
|
||||
],
|
||||
),
|
||||
ContainerDefinition(
|
||||
type=ContainerType.readonly_inputs,
|
||||
compare=Compare.AtLeast,
|
||||
value=0,
|
||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
||||
),
|
||||
ContainerDefinition(
|
||||
type=ContainerType.tools,
|
||||
compare=Compare.Equal,
|
||||
value=1,
|
||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
||||
),
|
||||
],
|
||||
monitor_queue=None,
|
||||
),
|
||||
TaskType.generic_analysis: TaskDefinition(
|
||||
features=[
|
||||
TaskFeature.target_exe,
|
||||
TaskFeature.target_options,
|
||||
TaskFeature.analyzer_exe,
|
||||
TaskFeature.analyzer_env,
|
||||
TaskFeature.analyzer_options,
|
||||
],
|
||||
vm=VmDefinition(compare=Compare.AtLeast, value=1),
|
||||
containers=[
|
||||
ContainerDefinition(
|
||||
type=ContainerType.setup,
|
||||
compare=Compare.Equal,
|
||||
value=1,
|
||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
||||
),
|
||||
ContainerDefinition(
|
||||
type=ContainerType.analysis,
|
||||
compare=Compare.Equal,
|
||||
value=1,
|
||||
permissions=[
|
||||
ContainerPermission.Write,
|
||||
ContainerPermission.Read,
|
||||
ContainerPermission.List,
|
||||
],
|
||||
),
|
||||
ContainerDefinition(
|
||||
type=ContainerType.crashes,
|
||||
compare=Compare.Equal,
|
||||
value=1,
|
||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
||||
),
|
||||
ContainerDefinition(
|
||||
type=ContainerType.tools,
|
||||
compare=Compare.Equal,
|
||||
value=1,
|
||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
||||
),
|
||||
],
|
||||
monitor_queue=ContainerType.crashes,
|
||||
),
|
||||
TaskType.libfuzzer_fuzz: TaskDefinition(
|
||||
features=[
|
||||
TaskFeature.target_exe,
|
||||
TaskFeature.target_env,
|
||||
TaskFeature.target_options,
|
||||
TaskFeature.target_workers,
|
||||
TaskFeature.ensemble_sync_delay,
|
||||
TaskFeature.check_fuzzer_help,
|
||||
TaskFeature.expect_crash_on_failure,
|
||||
],
|
||||
vm=VmDefinition(compare=Compare.AtLeast, value=1),
|
||||
containers=[
|
||||
ContainerDefinition(
|
||||
type=ContainerType.setup,
|
||||
compare=Compare.Equal,
|
||||
value=1,
|
||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
||||
),
|
||||
ContainerDefinition(
|
||||
type=ContainerType.crashes,
|
||||
compare=Compare.Equal,
|
||||
value=1,
|
||||
permissions=[ContainerPermission.Write],
|
||||
),
|
||||
ContainerDefinition(
|
||||
type=ContainerType.inputs,
|
||||
compare=Compare.Equal,
|
||||
value=1,
|
||||
permissions=[
|
||||
ContainerPermission.Write,
|
||||
ContainerPermission.Read,
|
||||
ContainerPermission.List,
|
||||
],
|
||||
),
|
||||
ContainerDefinition(
|
||||
type=ContainerType.readonly_inputs,
|
||||
compare=Compare.AtLeast,
|
||||
value=0,
|
||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
||||
),
|
||||
],
|
||||
monitor_queue=None,
|
||||
),
|
||||
TaskType.libfuzzer_crash_report: TaskDefinition(
|
||||
features=[
|
||||
TaskFeature.target_exe,
|
||||
TaskFeature.target_env,
|
||||
TaskFeature.target_options,
|
||||
TaskFeature.target_timeout,
|
||||
TaskFeature.check_retry_count,
|
||||
TaskFeature.check_fuzzer_help,
|
||||
TaskFeature.minimized_stack_depth,
|
||||
],
|
||||
vm=VmDefinition(compare=Compare.AtLeast, value=1),
|
||||
containers=[
|
||||
ContainerDefinition(
|
||||
type=ContainerType.setup,
|
||||
compare=Compare.Equal,
|
||||
value=1,
|
||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
||||
),
|
||||
ContainerDefinition(
|
||||
type=ContainerType.crashes,
|
||||
compare=Compare.Equal,
|
||||
value=1,
|
||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
||||
),
|
||||
ContainerDefinition(
|
||||
type=ContainerType.reports,
|
||||
compare=Compare.AtMost,
|
||||
value=1,
|
||||
permissions=[ContainerPermission.Write],
|
||||
),
|
||||
ContainerDefinition(
|
||||
type=ContainerType.unique_reports,
|
||||
compare=Compare.AtMost,
|
||||
value=1,
|
||||
permissions=[ContainerPermission.Write],
|
||||
),
|
||||
ContainerDefinition(
|
||||
type=ContainerType.no_repro,
|
||||
compare=Compare.AtMost,
|
||||
value=1,
|
||||
permissions=[ContainerPermission.Write],
|
||||
),
|
||||
],
|
||||
monitor_queue=ContainerType.crashes,
|
||||
),
|
||||
TaskType.libfuzzer_merge: TaskDefinition(
|
||||
features=[
|
||||
TaskFeature.target_exe,
|
||||
TaskFeature.target_env,
|
||||
TaskFeature.target_options,
|
||||
TaskFeature.check_fuzzer_help,
|
||||
],
|
||||
vm=VmDefinition(compare=Compare.Equal, value=1),
|
||||
containers=[
|
||||
ContainerDefinition(
|
||||
type=ContainerType.setup,
|
||||
compare=Compare.Equal,
|
||||
value=1,
|
||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
||||
),
|
||||
ContainerDefinition(
|
||||
type=ContainerType.unique_inputs,
|
||||
compare=Compare.Equal,
|
||||
value=1,
|
||||
permissions=[
|
||||
ContainerPermission.List,
|
||||
ContainerPermission.Read,
|
||||
ContainerPermission.Write,
|
||||
],
|
||||
),
|
||||
ContainerDefinition(
|
||||
type=ContainerType.inputs,
|
||||
compare=Compare.AtLeast,
|
||||
value=0,
|
||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
||||
),
|
||||
],
|
||||
monitor_queue=None,
|
||||
),
|
||||
TaskType.generic_supervisor: TaskDefinition(
|
||||
features=[
|
||||
TaskFeature.target_exe,
|
||||
TaskFeature.target_options,
|
||||
TaskFeature.supervisor_exe,
|
||||
TaskFeature.supervisor_env,
|
||||
TaskFeature.supervisor_options,
|
||||
TaskFeature.supervisor_input_marker,
|
||||
TaskFeature.wait_for_files,
|
||||
TaskFeature.stats_file,
|
||||
TaskFeature.ensemble_sync_delay,
|
||||
],
|
||||
        vm=VmDefinition(compare=Compare.AtLeast, value=1),
        containers=[
            ContainerDefinition(
                type=ContainerType.setup,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.tools,
                compare=Compare.AtMost,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.crashes,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Write],
            ),
            ContainerDefinition(
                type=ContainerType.inputs,
                compare=Compare.Equal,
                value=1,
                permissions=[
                    ContainerPermission.Write,
                    ContainerPermission.Read,
                    ContainerPermission.List,
                ],
            ),
            ContainerDefinition(
                type=ContainerType.unique_reports,
                compare=Compare.AtMost,
                value=1,
                permissions=[
                    ContainerPermission.Write,
                    ContainerPermission.Read,
                    ContainerPermission.List,
                ],
            ),
            ContainerDefinition(
                type=ContainerType.reports,
                compare=Compare.AtMost,
                value=1,
                permissions=[
                    ContainerPermission.Write,
                    ContainerPermission.Read,
                    ContainerPermission.List,
                ],
            ),
            ContainerDefinition(
                type=ContainerType.no_repro,
                compare=Compare.AtMost,
                value=1,
                permissions=[
                    ContainerPermission.Write,
                    ContainerPermission.Read,
                    ContainerPermission.List,
                ],
            ),
            ContainerDefinition(
                type=ContainerType.coverage,
                compare=Compare.AtMost,
                value=1,
                permissions=[
                    ContainerPermission.Write,
                    ContainerPermission.Read,
                    ContainerPermission.List,
                ],
            ),
        ],
        monitor_queue=None,
    ),
    TaskType.generic_merge: TaskDefinition(
        features=[
            TaskFeature.target_exe,
            TaskFeature.target_options,
            TaskFeature.supervisor_exe,
            TaskFeature.supervisor_env,
            TaskFeature.supervisor_options,
            TaskFeature.supervisor_input_marker,
            TaskFeature.stats_file,
            TaskFeature.preserve_existing_outputs,
        ],
        vm=VmDefinition(compare=Compare.AtLeast, value=1),
        containers=[
            ContainerDefinition(
                type=ContainerType.setup,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.tools,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.readonly_inputs,
                compare=Compare.AtLeast,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.inputs,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Write, ContainerPermission.List],
            ),
        ],
        monitor_queue=None,
    ),
    TaskType.generic_generator: TaskDefinition(
        features=[
            TaskFeature.generator_exe,
            TaskFeature.generator_env,
            TaskFeature.generator_options,
            TaskFeature.target_exe,
            TaskFeature.target_env,
            TaskFeature.target_options,
            TaskFeature.rename_output,
            TaskFeature.target_timeout,
            TaskFeature.check_asan_log,
            TaskFeature.check_debugger,
            TaskFeature.check_retry_count,
            TaskFeature.ensemble_sync_delay,
            TaskFeature.target_must_use_input,
        ],
        vm=VmDefinition(compare=Compare.AtLeast, value=1),
        containers=[
            ContainerDefinition(
                type=ContainerType.setup,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.tools,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.crashes,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Write],
            ),
            ContainerDefinition(
                type=ContainerType.readonly_inputs,
                compare=Compare.AtLeast,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
        ],
        monitor_queue=None,
    ),
    TaskType.generic_crash_report: TaskDefinition(
        features=[
            TaskFeature.target_exe,
            TaskFeature.target_env,
            TaskFeature.target_options,
            TaskFeature.target_timeout,
            TaskFeature.check_asan_log,
            TaskFeature.check_debugger,
            TaskFeature.check_retry_count,
            TaskFeature.minimized_stack_depth,
        ],
        vm=VmDefinition(compare=Compare.AtLeast, value=1),
        containers=[
            ContainerDefinition(
                type=ContainerType.setup,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.crashes,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.reports,
                compare=Compare.AtMost,
                value=1,
                permissions=[ContainerPermission.Write],
            ),
            ContainerDefinition(
                type=ContainerType.unique_reports,
                compare=Compare.AtMost,
                value=1,
                permissions=[ContainerPermission.Write],
            ),
            ContainerDefinition(
                type=ContainerType.no_repro,
                compare=Compare.AtMost,
                value=1,
                permissions=[ContainerPermission.Write],
            ),
        ],
        monitor_queue=ContainerType.crashes,
    ),
    TaskType.generic_regression: TaskDefinition(
        features=[
            TaskFeature.target_exe,
            TaskFeature.target_env,
            TaskFeature.target_options,
            TaskFeature.target_timeout,
            TaskFeature.check_asan_log,
            TaskFeature.check_debugger,
            TaskFeature.check_retry_count,
            TaskFeature.report_list,
            TaskFeature.minimized_stack_depth,
        ],
        vm=VmDefinition(compare=Compare.AtLeast, value=1),
        containers=[
            ContainerDefinition(
                type=ContainerType.setup,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.regression_reports,
                compare=Compare.Equal,
                value=1,
                permissions=[
                    ContainerPermission.Write,
                    ContainerPermission.Read,
                    ContainerPermission.List,
                ],
            ),
            ContainerDefinition(
                type=ContainerType.crashes,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.reports,
                compare=Compare.AtMost,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.unique_reports,
                compare=Compare.AtMost,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.no_repro,
                compare=Compare.AtMost,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.readonly_inputs,
                compare=Compare.AtMost,
                value=1,
                permissions=[
                    ContainerPermission.Read,
                    ContainerPermission.List,
                ],
            ),
        ],
    ),
    TaskType.libfuzzer_regression: TaskDefinition(
        features=[
            TaskFeature.target_exe,
            TaskFeature.target_env,
            TaskFeature.target_options,
            TaskFeature.target_timeout,
            TaskFeature.check_fuzzer_help,
            TaskFeature.check_retry_count,
            TaskFeature.report_list,
            TaskFeature.minimized_stack_depth,
        ],
        vm=VmDefinition(compare=Compare.AtLeast, value=1),
        containers=[
            ContainerDefinition(
                type=ContainerType.setup,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.regression_reports,
                compare=Compare.Equal,
                value=1,
                permissions=[
                    ContainerPermission.Write,
                    ContainerPermission.Read,
                    ContainerPermission.List,
                ],
            ),
            ContainerDefinition(
                type=ContainerType.crashes,
                compare=Compare.Equal,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.unique_reports,
                compare=Compare.AtMost,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.reports,
                compare=Compare.AtMost,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.no_repro,
                compare=Compare.AtMost,
                value=1,
                permissions=[ContainerPermission.Read, ContainerPermission.List],
            ),
            ContainerDefinition(
                type=ContainerType.readonly_inputs,
                compare=Compare.AtMost,
                value=1,
                permissions=[
                    ContainerPermission.Read,
                    ContainerPermission.List,
                ],
            ),
        ],
    ),
}
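The table above declares, for each task type, which features a request may set, how many VMs it needs, and which containers it must provide along with the permissions granted on them. As a rough illustration (not part of the deleted module; the helper name is invented), the Compare values behave like simple numeric checks against the number of containers supplied:

def container_count_ok(compare: Compare, expected: int, actual: int) -> bool:
    # Equal / AtLeast / AtMost mirror the 'compare' field on ContainerDefinition
    if compare == Compare.Equal:
        return actual == expected
    if compare == Compare.AtLeast:
        return actual >= expected
    return actual <= expected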
@ -1,360 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import logging
from datetime import datetime, timedelta
from typing import List, Optional, Tuple, Union
from uuid import UUID

from onefuzztypes.enums import ErrorCode, TaskState
from onefuzztypes.events import (
    EventTaskCreated,
    EventTaskFailed,
    EventTaskStateUpdated,
    EventTaskStopped,
)
from onefuzztypes.models import Error
from onefuzztypes.models import Task as BASE_TASK
from onefuzztypes.models import TaskConfig, TaskVm, UserInfo
from onefuzztypes.primitives import PoolName

from ..azure.image import get_os
from ..azure.queue import create_queue, delete_queue
from ..azure.storage import StorageType
from ..events import send_event
from ..orm import MappingIntStrAny, ORMMixin, QueryFilter
from ..workers.nodes import Node, NodeTasks
from ..workers.pools import Pool
from ..workers.scalesets import Scaleset


class Task(BASE_TASK, ORMMixin):
    def check_prereq_tasks(self) -> bool:
        if self.config.prereq_tasks:
            for task_id in self.config.prereq_tasks:
                task = Task.get_by_task_id(task_id)
                # if a prereq task fails, then mark this task as failed
                if isinstance(task, Error):
                    self.mark_failed(task)
                    return False

                if task.state not in task.state.has_started():
                    return False
        return True

    @classmethod
    def create(
        cls, config: TaskConfig, job_id: UUID, user_info: UserInfo
    ) -> Union["Task", Error]:
        if config.vm:
            os = get_os(config.vm.region, config.vm.image)
            if isinstance(os, Error):
                return os
        elif config.pool:
            pool = Pool.get_by_name(config.pool.pool_name)
            if isinstance(pool, Error):
                return pool
            os = pool.os
        else:
            raise Exception("task must have vm or pool")
        task = cls(config=config, job_id=job_id, os=os, user_info=user_info)
        task.save()
        send_event(
            EventTaskCreated(
                job_id=task.job_id,
                task_id=task.task_id,
                config=config,
                user_info=user_info,
            )
        )

        logging.info(
            "created task. job_id:%s task_id:%s type:%s",
            task.job_id,
            task.task_id,
            task.config.task.type.name,
        )
        return task

    def save_exclude(self) -> Optional[MappingIntStrAny]:
        return {"heartbeats": ...}

    def is_ready(self) -> bool:
        if self.config.prereq_tasks:
            for prereq_id in self.config.prereq_tasks:
                prereq = Task.get_by_task_id(prereq_id)
                if isinstance(prereq, Error):
                    logging.info("task prereq has error: %s - %s", self.task_id, prereq)
                    self.mark_failed(prereq)
                    return False
                if prereq.state != TaskState.running:
                    logging.info(
                        "task is waiting on prereq: %s - %s",
                        self.task_id,
                        prereq.task_id,
                    )
                    return False

        return True

    # Currently, the telemetry filter will generate something like this:
    #
    # {
    #     'job_id': 'f4a20fd8-0dcc-4a4e-8804-6ef7df50c978',
    #     'task_id': '835f7b3f-43ad-4718-b7e4-d506d9667b09',
    #     'state': 'stopped',
    #     'config': {
    #         'task': {'type': 'coverage'},
    #         'vm': {'count': 1}
    #     }
    # }
    def telemetry_include(self) -> Optional[MappingIntStrAny]:
        return {
            "job_id": ...,
            "task_id": ...,
            "state": ...,
            "config": {"vm": {"count": ...}, "task": {"type": ...}},
        }

    def init(self) -> None:
        create_queue(self.task_id, StorageType.corpus)
        self.set_state(TaskState.waiting)

    def stopping(self) -> None:
        logging.info("stopping task: %s:%s", self.job_id, self.task_id)
        Node.stop_task(self.task_id)
        if not NodeTasks.get_nodes_by_task_id(self.task_id):
            self.stopped()

    def stopped(self) -> None:
        self.set_state(TaskState.stopped)
        delete_queue(str(self.task_id), StorageType.corpus)

        # TODO: we need to 'unschedule' this task from the existing pools
        from ..jobs import Job

        job = Job.get(self.job_id)
        if job:
            job.stop_if_all_done()

    @classmethod
    def search_states(
        cls, *, job_id: Optional[UUID] = None, states: Optional[List[TaskState]] = None
    ) -> List["Task"]:
        query: QueryFilter = {}
        if job_id:
            query["job_id"] = [job_id]
        if states:
            query["state"] = states

        return cls.search(query=query)

    @classmethod
    def search_expired(cls) -> List["Task"]:
        time_filter = "end_time lt datetime'%s'" % datetime.utcnow().isoformat()
        return cls.search(
            query={"state": TaskState.available()}, raw_unchecked_filter=time_filter
        )

    @classmethod
    def get_by_task_id(cls, task_id: UUID) -> Union[Error, "Task"]:
        tasks = cls.search(query={"task_id": [task_id]})
        if not tasks:
            return Error(code=ErrorCode.INVALID_REQUEST, errors=["unable to find task"])

        if len(tasks) != 1:
            return Error(
                code=ErrorCode.INVALID_REQUEST, errors=["error identifying task"]
            )
        task = tasks[0]
        return task

    @classmethod
    def get_tasks_by_pool_name(cls, pool_name: PoolName) -> List["Task"]:
        tasks = cls.search_states(states=TaskState.available())
        if not tasks:
            return []

        pool_tasks = []

        for task in tasks:
            task_pool = task.get_pool()
            if not task_pool:
                continue
            if pool_name == task_pool.name:
                pool_tasks.append(task)

        return pool_tasks

    def mark_stopping(self) -> None:
        if self.state in TaskState.shutting_down():
            logging.debug(
                "ignoring post-task stop calls to stop %s:%s", self.job_id, self.task_id
            )
            return

        if self.state not in TaskState.has_started():
            self.mark_failed(
                Error(code=ErrorCode.TASK_FAILED, errors=["task never started"])
            )

        self.set_state(TaskState.stopping)

    def mark_failed(
        self, error: Error, tasks_in_job: Optional[List["Task"]] = None
    ) -> None:
        if self.state in TaskState.shutting_down():
            logging.debug(
                "ignoring post-task stop failures for %s:%s", self.job_id, self.task_id
            )
            return

        if self.error is not None:
            logging.debug(
                "ignoring additional task error %s:%s", self.job_id, self.task_id
            )
            return

        logging.error("task failed %s:%s - %s", self.job_id, self.task_id, error)

        self.error = error
        self.set_state(TaskState.stopping)

        self.mark_dependants_failed(tasks_in_job=tasks_in_job)

    def mark_dependants_failed(
        self, tasks_in_job: Optional[List["Task"]] = None
    ) -> None:
        if tasks_in_job is None:
            tasks_in_job = Task.search(query={"job_id": [self.job_id]})

        for task in tasks_in_job:
            if task.config.prereq_tasks:
                if self.task_id in task.config.prereq_tasks:
                    task.mark_failed(
                        Error(
                            code=ErrorCode.TASK_FAILED,
                            errors=[
                                "prerequisite task failed. task_id:%s" % self.task_id
                            ],
                        ),
                        tasks_in_job,
                    )

    def get_pool(self) -> Optional[Pool]:
        if self.config.pool:
            pool = Pool.get_by_name(self.config.pool.pool_name)
            if isinstance(pool, Error):
                logging.info(
                    "unable to schedule task to pool: %s - %s", self.task_id, pool
                )
                return None
            return pool
        elif self.config.vm:
            scalesets = Scaleset.search()
            scalesets = [
                x
                for x in scalesets
                if x.vm_sku == self.config.vm.sku and x.image == self.config.vm.image
            ]
            for scaleset in scalesets:
                pool = Pool.get_by_name(scaleset.pool_name)
                if isinstance(pool, Error):
                    logging.info(
                        "unable to schedule task to pool: %s - %s",
                        self.task_id,
                        pool,
                    )
                else:
                    return pool

        logging.warning(
            "unable to find a scaleset that matches the task prereqs: %s",
            self.task_id,
        )
        return None

    def get_repro_vm_config(self) -> Union[TaskVm, None]:
        if self.config.vm:
            return self.config.vm

        if self.config.pool is None:
            raise Exception("either pool or vm must be specified: %s" % self.task_id)

        pool = Pool.get_by_name(self.config.pool.pool_name)
        if isinstance(pool, Error):
            logging.info("unable to find pool from task: %s", self.task_id)
            return None

        for scaleset in Scaleset.search_by_pool(self.config.pool.pool_name):
            return TaskVm(
                region=scaleset.region,
                sku=scaleset.vm_sku,
                image=scaleset.image,
            )

        logging.warning(
            "no scalesets are defined for task: %s:%s", self.job_id, self.task_id
        )
        return None

    def on_start(self) -> None:
        # try to keep this effectively idempotent
        if self.end_time is None:
            self.end_time = datetime.utcnow() + timedelta(
                hours=self.config.task.duration
            )

            from ..jobs import Job

            job = Job.get(self.job_id)
            if job:
                job.on_start()

    @classmethod
    def key_fields(cls) -> Tuple[str, str]:
        return ("job_id", "task_id")

    def set_state(self, state: TaskState) -> None:
        if self.state == state:
            return

        self.state = state
        if self.state in [TaskState.running, TaskState.setting_up]:
            self.on_start()

        self.save()

        if self.state == TaskState.stopped:
            if self.error:
                send_event(
                    EventTaskFailed(
                        job_id=self.job_id,
                        task_id=self.task_id,
                        error=self.error,
                        user_info=self.user_info,
                        config=self.config,
                    )
                )
            else:
                send_event(
                    EventTaskStopped(
                        job_id=self.job_id,
                        task_id=self.task_id,
                        user_info=self.user_info,
                        config=self.config,
                    )
                )
        else:
            send_event(
                EventTaskStateUpdated(
                    job_id=self.job_id,
                    task_id=self.task_id,
                    state=self.state,
                    end_time=self.end_time,
                    config=self.config,
                )
            )
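A minimal sketch of how the failure-propagation path above behaves, assuming two tasks where task_b lists task_a in config.prereq_tasks (the variable names are hypothetical, not from the deleted file):

error = Error(code=ErrorCode.TASK_FAILED, errors=["fuzzer crashed on startup"])
task_a.mark_failed(error)
# mark_failed records the error, moves task_a to TaskState.stopping, and
# mark_dependants_failed then fails task_b with
# "prerequisite task failed. task_id:<task_a.task_id>"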
@ -1,248 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import logging
from typing import Dict, Generator, List, Optional, Tuple, TypeVar
from uuid import UUID, uuid4

from onefuzztypes.enums import OS, PoolState, TaskState
from onefuzztypes.models import WorkSet, WorkUnit
from onefuzztypes.primitives import Container
from pydantic import BaseModel

from ..azure.containers import blob_exists, get_container_sas_url
from ..azure.storage import StorageType
from ..jobs import Job
from ..workers.pools import Pool
from .config import build_task_config, get_setup_container
from .main import Task

HOURS = 60 * 60

# TODO: eventually, this should be tied to the pool.
MAX_TASKS_PER_SET = 10


A = TypeVar("A")


def chunks(items: List[A], size: int) -> Generator[List[A], None, None]:
    return (items[x : x + size] for x in range(0, len(items), size))


def schedule_workset(workset: WorkSet, pool: Pool, count: int) -> bool:
    if pool.state not in PoolState.available():
        logging.info(
            "pool not available for work: %s state: %s", pool.name, pool.state.name
        )
        return False

    for _ in range(count):
        if not pool.schedule_workset(workset):
            logging.error(
                "unable to schedule workset. pool:%s workset:%s", pool.name, workset
            )
            return False
    return True


# TODO - Once Pydantic supports hashable models, the Tuple should be replaced
# with a model.
#
# For info: https://github.com/samuelcolvin/pydantic/pull/1881


def bucket_tasks(tasks: List[Task]) -> Dict[Tuple, List[Task]]:
    # buckets are hashed by:
    # OS, JOB ID, vm sku & image (if available), pool name (if available),
    # if the setup script requires rebooting, and a 'unique' value
    #
    # The unique value is set based on the following conditions:
    # * if the task is set to run on more than one VM, then we assume it
    #   can't be shared
    # * if the task is missing the 'colocate' flag or it's set to False

    buckets: Dict[Tuple, List[Task]] = {}

    for task in tasks:
        vm: Optional[Tuple[str, str]] = None
        pool: Optional[str] = None
        unique: Optional[UUID] = None

        # check for multiple VMs for pre-1.0.0 tasks
        if task.config.vm:
            vm = (task.config.vm.sku, task.config.vm.image)
            if task.config.vm.count > 1:
                unique = uuid4()

        # check for multiple VMs for 1.0.0 and later tasks
        if task.config.pool:
            pool = task.config.pool.pool_name
            if task.config.pool.count > 1:
                unique = uuid4()

        if not task.config.colocate:
            unique = uuid4()

        key = (
            task.os,
            task.job_id,
            vm,
            pool,
            get_setup_container(task.config),
            task.config.task.reboot_after_setup,
            unique,
        )
        if key not in buckets:
            buckets[key] = []
        buckets[key].append(task)

    return buckets


class BucketConfig(BaseModel):
    count: int
    reboot: bool
    setup_container: Container
    setup_script: Optional[str]
    pool: Pool


def build_work_unit(task: Task) -> Optional[Tuple[BucketConfig, WorkUnit]]:
    pool = task.get_pool()
    if not pool:
        logging.info("unable to find pool for task: %s", task.task_id)
        return None

    logging.info("scheduling task: %s", task.task_id)

    job = Job.get(task.job_id)
    if not job:
        raise Exception(f"invalid job_id {task.job_id} for task {task.task_id}")

    task_config = build_task_config(job, task)

    setup_container = get_setup_container(task.config)
    setup_script = None

    if task.os == OS.windows and blob_exists(
        setup_container, "setup.ps1", StorageType.corpus
    ):
        setup_script = "setup.ps1"
    if task.os == OS.linux and blob_exists(
        setup_container, "setup.sh", StorageType.corpus
    ):
        setup_script = "setup.sh"

    reboot = False
    count = 1
    if task.config.pool:
        count = task.config.pool.count

        # NOTE: "is True" is required to handle Optional[bool]
        reboot = task.config.task.reboot_after_setup is True
    elif task.config.vm:
        # this branch should go away when we stop letting people specify
        # VM configs directly.
        count = task.config.vm.count

        # NOTE: "is True" is required to handle Optional[bool]
        reboot = (
            task.config.vm.reboot_after_setup is True
            or task.config.task.reboot_after_setup is True
        )
    else:
        raise TypeError

    work_unit = WorkUnit(
        job_id=task_config.job_id,
        task_id=task_config.task_id,
        task_type=task_config.task_type,
        config=task_config.json(exclude_none=True, exclude_unset=True),
    )

    bucket_config = BucketConfig(
        pool=pool,
        count=count,
        reboot=reboot,
        setup_script=setup_script,
        setup_container=setup_container,
    )

    return bucket_config, work_unit


def build_work_set(tasks: List[Task]) -> Optional[Tuple[BucketConfig, WorkSet]]:
    task_ids = [x.task_id for x in tasks]

    bucket_config: Optional[BucketConfig] = None
    work_units = []

    for task in tasks:
        if task.config.prereq_tasks:
            # if all of the prereqs are in this bucket, they will be
            # scheduled together
            if not all([task_id in task_ids for task_id in task.config.prereq_tasks]):
                if not task.check_prereq_tasks():
                    continue

        result = build_work_unit(task)
        if not result:
            continue

        new_bucket_config, work_unit = result
        if bucket_config is None:
            bucket_config = new_bucket_config
        else:
            if bucket_config != new_bucket_config:
                raise Exception(
                    f"bucket configs differ: {bucket_config} VS {new_bucket_config}"
                )

        work_units.append(work_unit)

    if bucket_config:
        setup_url = get_container_sas_url(
            bucket_config.setup_container, StorageType.corpus, read=True, list_=True
        )

        work_set = WorkSet(
            reboot=bucket_config.reboot,
            script=(bucket_config.setup_script is not None),
            setup_url=setup_url,
            work_units=work_units,
        )
        return (bucket_config, work_set)

    return None


def schedule_tasks() -> None:
    tasks: List[Task] = []

    tasks = Task.search_states(states=[TaskState.waiting])

    tasks_by_id = {x.task_id: x for x in tasks}
    seen = set()

    not_ready_count = 0

    buckets = bucket_tasks(tasks)

    for bucketed_tasks in buckets.values():
        for chunk in chunks(bucketed_tasks, MAX_TASKS_PER_SET):
            result = build_work_set(chunk)
            if result is None:
                continue
            bucket_config, work_set = result

            if schedule_workset(work_set, bucket_config.pool, bucket_config.count):
                for work_unit in work_set.work_units:
                    task = tasks_by_id[work_unit.task_id]
                    task.set_state(TaskState.scheduled)
                    seen.add(task.task_id)

    not_ready_count = len(tasks) - len(seen)
    if not_ready_count > 0:
        logging.info("tasks not ready: %d", not_ready_count)
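The scheduler groups waiting tasks into buckets and then splits each bucket with chunks() so that no work set carries more than MAX_TASKS_PER_SET units. A quick standalone check of that helper (illustrative only, not part of the deleted file):

assert list(chunks(list(range(25)), 10)) == [
    list(range(0, 10)),
    list(range(10, 20)),
    list(range(20, 25)),
]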
@ -1,72 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import logging
import os
from typing import Any, Dict, Optional, Union

from onefuzztypes.enums import TelemetryData, TelemetryEvent
from opencensus.ext.azure.log_exporter import AzureLogHandler

LOCAL_CLIENT: Optional[logging.Logger] = None
CENTRAL_CLIENT: Optional[logging.Logger] = None


def _get_client(environ_key: str) -> Optional[logging.Logger]:
    key = os.environ.get(environ_key)
    if key is None:
        return None
    client = logging.getLogger("onefuzz")
    client.addHandler(AzureLogHandler(connection_string="InstrumentationKey=%s" % key))
    return client


def _central_client() -> Optional[logging.Logger]:
    global CENTRAL_CLIENT
    if not CENTRAL_CLIENT:
        CENTRAL_CLIENT = _get_client("ONEFUZZ_TELEMETRY")
    return CENTRAL_CLIENT


def _local_client() -> Union[None, Any, logging.Logger]:
    global LOCAL_CLIENT
    if not LOCAL_CLIENT:
        LOCAL_CLIENT = _get_client("APPINSIGHTS_INSTRUMENTATIONKEY")
    return LOCAL_CLIENT


# NOTE: All telemetry that is *NOT* using the ORM telemetry_include should
# go through this method
#
# This provides a point of inspection to know if it's data that is safe to
# log to the central OneFuzz telemetry point
def track_event(
    event: TelemetryEvent, data: Dict[TelemetryData, Union[str, int]]
) -> None:
    central = _central_client()
    local = _local_client()

    if local:
        serialized = {k.name: v for (k, v) in data.items()}
        local.info(event.name, extra={"custom_dimensions": serialized})

    if event in TelemetryEvent.can_share() and central:
        serialized = {
            k.name: v for (k, v) in data.items() if k in TelemetryData.can_share()
        }
        central.info(event.name, extra={"custom_dimensions": serialized})


# NOTE: This should *only* be used for logging Telemetry data that uses
# the ORM telemetry_include method to limit data for telemetry.
def track_event_filtered(event: TelemetryEvent, data: Any) -> None:
    central = _central_client()
    local = _local_client()

    if local:
        local.info(event.name, extra={"custom_dimensions": data})

    if central and event in TelemetryEvent.can_share():
        central.info(event.name, extra={"custom_dimensions": data})
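Callers pass an event plus a dict keyed by TelemetryData; only events and keys listed in the respective can_share() sets ever reach the central client. A hedged usage sketch (the enum member names below are assumptions, not taken from the deleted code):

# track_event(
#     TelemetryEvent.task_created,
#     {TelemetryData.task_id: str(task_id)},
# )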
@ -1,114 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import logging
from typing import Dict, Optional, Type

from azure.core.exceptions import ResourceNotFoundError
from msrestazure.azure_exceptions import CloudError
from onefuzztypes.enums import UpdateType
from pydantic import BaseModel

from .azure.queue import queue_object
from .azure.storage import StorageType


# This class isn't intended to be shared outside of the service
class Update(BaseModel):
    update_type: UpdateType
    PartitionKey: Optional[str]
    RowKey: Optional[str]
    method: Optional[str]


def queue_update(
    update_type: UpdateType,
    PartitionKey: Optional[str] = None,
    RowKey: Optional[str] = None,
    method: Optional[str] = None,
    visibility_timeout: Optional[int] = None,
) -> None:
    logging.info(
        "queuing type:%s id:[%s,%s] method:%s timeout: %s",
        update_type.name,
        PartitionKey,
        RowKey,
        method,
        visibility_timeout,
    )

    update = Update(
        update_type=update_type, PartitionKey=PartitionKey, RowKey=RowKey, method=method
    )

    try:
        if not queue_object(
            "update-queue",
            update,
            StorageType.config,
            visibility_timeout=visibility_timeout,
        ):
            logging.error("unable to queue update")
    except (CloudError, ResourceNotFoundError) as err:
        logging.error("GOT ERROR %s", repr(err))
        logging.error("GOT ERROR %s", dir(err))
        raise err


def execute_update(update: Update) -> None:
    from .jobs import Job
    from .orm import ORMMixin
    from .proxy import Proxy
    from .repro import Repro
    from .tasks.main import Task
    from .workers.nodes import Node
    from .workers.pools import Pool
    from .workers.scalesets import Scaleset

    update_objects: Dict[UpdateType, Type[ORMMixin]] = {
        UpdateType.Task: Task,
        UpdateType.Job: Job,
        UpdateType.Repro: Repro,
        UpdateType.Proxy: Proxy,
        UpdateType.Pool: Pool,
        UpdateType.Node: Node,
        UpdateType.Scaleset: Scaleset,
    }

    # TODO: remove these from being queued, these updates are handled elsewhere
    if update.update_type == UpdateType.Scaleset:
        return

    if update.update_type in update_objects:
        if update.PartitionKey is None or update.RowKey is None:
            raise Exception("unsupported update: %s" % update)

        obj = update_objects[update.update_type].get(update.PartitionKey, update.RowKey)
        if not obj:
            logging.error("unable to find obj to update %s", update)
            return

        if update.method and hasattr(obj, update.method):
            logging.info("performing queued update: %s", update)
            getattr(obj, update.method)()
            return
        else:
            state = getattr(obj, "state", None)
            if state is None:
                logging.error("queued update for object without state: %s", update)
                return
            func = getattr(obj, state.name, None)
            if func is None:
                logging.debug(
                    "no function to implement state: %s - %s", update, state.name
                )
                return
            logging.info(
                "performing queued update for state: %s - %s", update, state.name
            )
            func()
            return

    raise NotImplementedError("unimplemented update type: %s" % update.update_type.name)
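A short usage sketch for the queuing side, assuming the caller wants a task's state function re-run after 30 seconds; per Task.key_fields above, the partition and row keys are the job and task IDs (the job_id and task_id variables here are placeholders):

queue_update(
    UpdateType.Task,
    PartitionKey=str(job_id),
    RowKey=str(task_id),
    visibility_timeout=30,
)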
@ -1,79 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import logging
from typing import List, Optional
from uuid import UUID

import azure.functions as func
import jwt
from memoization import cached
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error, Result, UserInfo

from .config import InstanceConfig


def get_bearer_token(request: func.HttpRequest) -> Optional[str]:
    auth: str = request.headers.get("Authorization", None)
    if not auth:
        return None

    parts = auth.split()

    if len(parts) != 2:
        return None

    if parts[0].lower() != "bearer":
        return None

    return parts[1]


def get_auth_token(request: func.HttpRequest) -> Optional[str]:
    token = get_bearer_token(request)
    if token is not None:
        return token

    token_header = request.headers.get("x-ms-token-aad-id-token", None)
    if token_header is None:
        return None
    return str(token_header)


@cached(ttl=60)
def get_allowed_tenants() -> List[str]:
    config = InstanceConfig.fetch()
    entries = [f"https://sts.windows.net/{x}/" for x in config.allowed_aad_tenants]
    return entries


def parse_jwt_token(request: func.HttpRequest) -> Result[UserInfo]:
    """Obtains the Access Token from the Authorization Header"""
    token_str = get_auth_token(request)
    if token_str is None:
        return Error(
            code=ErrorCode.INVALID_REQUEST,
            errors=["unable to find authorization token"],
        )

    # The JWT token has already been verified by the azure authentication layer,
    # but we need to verify the tenant is as we expect.
    token = jwt.decode(token_str, options={"verify_signature": False})

    if "iss" not in token:
        return Error(
            code=ErrorCode.INVALID_REQUEST, errors=["missing issuer from token"]
        )

    tenants = get_allowed_tenants()
    if token["iss"] not in tenants:
        logging.error("issuer not from allowed tenant: %s - %s", token["iss"], tenants)
        return Error(code=ErrorCode.INVALID_REQUEST, errors=["unauthorized AAD issuer"])

    application_id = UUID(token["appid"]) if "appid" in token else None
    object_id = UUID(token["oid"]) if "oid" in token else None
    upn = token.get("upn")
    return UserInfo(application_id=application_id, object_id=object_id, upn=upn)
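A minimal sketch of how an HTTP handler could consume parse_jwt_token, assuming an azure.functions HTTP trigger (the handler itself is hypothetical, not part of the deleted file):

def main(req: func.HttpRequest) -> func.HttpResponse:
    user = parse_jwt_token(req)
    if isinstance(user, Error):
        # reject requests whose token is missing or from a disallowed tenant
        return func.HttpResponse(";".join(user.errors), status_code=401)
    return func.HttpResponse("hello %s" % user.upn, status_code=200)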