mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-15 19:38:11 +00:00
@ -29,13 +29,6 @@ echo "layout python3" >> .envrc
|
|||||||
direnv allow
|
direnv allow
|
||||||
pip install -e .
|
pip install -e .
|
||||||
|
|
||||||
echo "Install api-service"
|
|
||||||
cd /workspaces/onefuzz/src/api-service
|
|
||||||
echo "layout python3" >> .envrc
|
|
||||||
direnv allow
|
|
||||||
pip install -r requirements-dev.txt
|
|
||||||
cd __app__
|
|
||||||
pip install -r requirements.txt
|
|
||||||
|
|
||||||
cd /workspaces/onefuzz/src/utils
|
cd /workspaces/onefuzz/src/utils
|
||||||
chmod u+x lint.sh
|
chmod u+x lint.sh
|
||||||
|
1
.github/codeql/codeql-config.yml
vendored
1
.github/codeql/codeql-config.yml
vendored
@ -3,4 +3,3 @@ paths:
|
|||||||
- src/agent
|
- src/agent
|
||||||
- src/pytypes
|
- src/pytypes
|
||||||
- src/deployment
|
- src/deployment
|
||||||
- src/api-service/__app__
|
|
||||||
|
@ -1,7 +0,0 @@
|
|||||||
[flake8]
|
|
||||||
# Recommend matching the black line length (default 88),
|
|
||||||
# rather than using the flake8 default of 79:
|
|
||||||
max-line-length = 88
|
|
||||||
extend-ignore =
|
|
||||||
# See https://github.com/PyCQA/pycodestyle/issues/373
|
|
||||||
E203,
|
|
49
src/api-service/.gitignore
vendored
49
src/api-service/.gitignore
vendored
@ -1,49 +0,0 @@
|
|||||||
bin
|
|
||||||
obj
|
|
||||||
csx
|
|
||||||
.vs
|
|
||||||
edge
|
|
||||||
Publish
|
|
||||||
|
|
||||||
*.user
|
|
||||||
*.suo
|
|
||||||
*.cscfg
|
|
||||||
*.Cache
|
|
||||||
project.lock.json
|
|
||||||
|
|
||||||
/packages
|
|
||||||
/TestResults
|
|
||||||
|
|
||||||
/tools/NuGet.exe
|
|
||||||
/App_Data
|
|
||||||
/secrets
|
|
||||||
/data
|
|
||||||
.secrets
|
|
||||||
appsettings.json
|
|
||||||
local.settings.json
|
|
||||||
|
|
||||||
node_modules
|
|
||||||
dist
|
|
||||||
|
|
||||||
# Local python packages
|
|
||||||
.python_packages/
|
|
||||||
|
|
||||||
# Python Environments
|
|
||||||
.env
|
|
||||||
.venv
|
|
||||||
env/
|
|
||||||
venv/
|
|
||||||
ENV/
|
|
||||||
env.bak/
|
|
||||||
venv.bak/
|
|
||||||
|
|
||||||
# Byte-compiled / optimized / DLL files
|
|
||||||
__pycache__/
|
|
||||||
*.py[cod]
|
|
||||||
*$py.class
|
|
||||||
.dockerenv
|
|
||||||
host_secrets.json
|
|
||||||
local.settings.json
|
|
||||||
.mypy_cache/
|
|
||||||
|
|
||||||
__app__/onefuzzlib/build.id
|
|
@ -1,9 +0,0 @@
|
|||||||
If you are doing development on the API service, you can build and deploy directly to your own instance using the azure-functions-core-tools.
|
|
||||||
|
|
||||||
From the api-service directory, do the following:
|
|
||||||
|
|
||||||
func azure functionapp publish <instance>
|
|
||||||
|
|
||||||
While Azure Functions will restart your instance with the new code, it may take a while. It may be helpful to restart your instance after pushing by doing the following:
|
|
||||||
|
|
||||||
az functionapp restart -g <group> -n <instance>
|
|
@ -1 +0,0 @@
|
|||||||
local.settings.json
|
|
4
src/api-service/__app__/.gitignore
vendored
4
src/api-service/__app__/.gitignore
vendored
@ -1,4 +0,0 @@
|
|||||||
.direnv
|
|
||||||
.python_packages
|
|
||||||
__pycache__
|
|
||||||
.venv
|
|
@ -1,17 +0,0 @@
|
|||||||
import os
|
|
||||||
|
|
||||||
import certifi.core
|
|
||||||
|
|
||||||
|
|
||||||
def override_where() -> str:
|
|
||||||
"""overrides certifi.core.where to return actual location of cacert.pem"""
|
|
||||||
# see:
|
|
||||||
# https://github.com/Azure/azure-functions-durable-python/issues/194#issuecomment-710670377
|
|
||||||
# change this to match the location of cacert.pem
|
|
||||||
return os.path.abspath(
|
|
||||||
"cacert.pem"
|
|
||||||
) # or to whatever location you know contains the copy of cacert.pem
|
|
||||||
|
|
||||||
|
|
||||||
os.environ["REQUESTS_CA_BUNDLE"] = override_where()
|
|
||||||
certifi.core.where = override_where
|
|
@ -1,53 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import azure.functions as func
|
|
||||||
from onefuzztypes.enums import ErrorCode, TaskState
|
|
||||||
from onefuzztypes.models import Error
|
|
||||||
from onefuzztypes.requests import CanScheduleRequest
|
|
||||||
from onefuzztypes.responses import CanSchedule
|
|
||||||
|
|
||||||
from ..onefuzzlib.endpoint_authorization import call_if_agent
|
|
||||||
from ..onefuzzlib.request import not_ok, ok, parse_request
|
|
||||||
from ..onefuzzlib.tasks.main import Task
|
|
||||||
from ..onefuzzlib.workers.nodes import Node
|
|
||||||
|
|
||||||
|
|
||||||
def post(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
request = parse_request(CanScheduleRequest, req)
|
|
||||||
if isinstance(request, Error):
|
|
||||||
return not_ok(request, context="CanScheduleRequest")
|
|
||||||
|
|
||||||
node = Node.get_by_machine_id(request.machine_id)
|
|
||||||
if not node:
|
|
||||||
return not_ok(
|
|
||||||
Error(code=ErrorCode.UNABLE_TO_FIND, errors=["unable to find node"]),
|
|
||||||
context=request.machine_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
allowed = True
|
|
||||||
work_stopped = False
|
|
||||||
|
|
||||||
if not node.can_process_new_work():
|
|
||||||
allowed = False
|
|
||||||
|
|
||||||
task = Task.get_by_task_id(request.task_id)
|
|
||||||
|
|
||||||
work_stopped = isinstance(task, Error) or task.state in TaskState.shutting_down()
|
|
||||||
if work_stopped:
|
|
||||||
allowed = False
|
|
||||||
|
|
||||||
if allowed:
|
|
||||||
allowed = not isinstance(node.acquire_scale_in_protection(), Error)
|
|
||||||
|
|
||||||
return ok(CanSchedule(allowed=allowed, work_stopped=work_stopped))
|
|
||||||
|
|
||||||
|
|
||||||
def main(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
methods = {"POST": post}
|
|
||||||
method = methods[req.method]
|
|
||||||
result = call_if_agent(req, method)
|
|
||||||
|
|
||||||
return result
|
|
@ -1,20 +0,0 @@
|
|||||||
{
|
|
||||||
"scriptFile": "__init__.py",
|
|
||||||
"bindings": [
|
|
||||||
{
|
|
||||||
"authLevel": "anonymous",
|
|
||||||
"type": "httpTrigger",
|
|
||||||
"direction": "in",
|
|
||||||
"name": "req",
|
|
||||||
"methods": [
|
|
||||||
"post"
|
|
||||||
],
|
|
||||||
"route": "agents/can_schedule"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "http",
|
|
||||||
"direction": "out",
|
|
||||||
"name": "$return"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
@ -1,50 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import azure.functions as func
|
|
||||||
from onefuzztypes.models import Error, NodeCommandEnvelope
|
|
||||||
from onefuzztypes.requests import NodeCommandDelete, NodeCommandGet
|
|
||||||
from onefuzztypes.responses import BoolResult, PendingNodeCommand
|
|
||||||
|
|
||||||
from ..onefuzzlib.endpoint_authorization import call_if_agent
|
|
||||||
from ..onefuzzlib.request import not_ok, ok, parse_request
|
|
||||||
from ..onefuzzlib.workers.nodes import NodeMessage
|
|
||||||
|
|
||||||
|
|
||||||
def get(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
request = parse_request(NodeCommandGet, req)
|
|
||||||
|
|
||||||
if isinstance(request, Error):
|
|
||||||
return not_ok(request, context="NodeCommandGet")
|
|
||||||
|
|
||||||
messages = NodeMessage.get_messages(request.machine_id, num_results=1)
|
|
||||||
|
|
||||||
if messages:
|
|
||||||
command = messages[0].message
|
|
||||||
message_id = messages[0].message_id
|
|
||||||
envelope = NodeCommandEnvelope(command=command, message_id=message_id)
|
|
||||||
|
|
||||||
return ok(PendingNodeCommand(envelope=envelope))
|
|
||||||
else:
|
|
||||||
return ok(PendingNodeCommand(envelope=None))
|
|
||||||
|
|
||||||
|
|
||||||
def delete(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
request = parse_request(NodeCommandDelete, req)
|
|
||||||
if isinstance(request, Error):
|
|
||||||
return not_ok(request, context="NodeCommandDelete")
|
|
||||||
|
|
||||||
message = NodeMessage.get(request.machine_id, request.message_id)
|
|
||||||
if message:
|
|
||||||
message.delete()
|
|
||||||
return ok(BoolResult(result=True))
|
|
||||||
|
|
||||||
|
|
||||||
def main(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
methods = {"DELETE": delete, "GET": get}
|
|
||||||
method = methods[req.method]
|
|
||||||
result = call_if_agent(req, method)
|
|
||||||
|
|
||||||
return result
|
|
@ -1,22 +0,0 @@
|
|||||||
{
|
|
||||||
"scriptFile": "__init__.py",
|
|
||||||
"bindings": [
|
|
||||||
{
|
|
||||||
"authLevel": "anonymous",
|
|
||||||
"type": "httpTrigger",
|
|
||||||
"direction": "in",
|
|
||||||
"name": "req",
|
|
||||||
"methods": [
|
|
||||||
"get",
|
|
||||||
"post",
|
|
||||||
"delete"
|
|
||||||
],
|
|
||||||
"route": "agents/commands"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "http",
|
|
||||||
"direction": "out",
|
|
||||||
"name": "$return"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
@ -1,79 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
import azure.functions as func
|
|
||||||
from onefuzztypes.models import (
|
|
||||||
Error,
|
|
||||||
NodeEvent,
|
|
||||||
NodeEventEnvelope,
|
|
||||||
NodeStateUpdate,
|
|
||||||
Result,
|
|
||||||
WorkerEvent,
|
|
||||||
)
|
|
||||||
from onefuzztypes.responses import BoolResult
|
|
||||||
|
|
||||||
from ..onefuzzlib.agent_events import on_state_update, on_worker_event
|
|
||||||
from ..onefuzzlib.endpoint_authorization import call_if_agent
|
|
||||||
from ..onefuzzlib.request import not_ok, ok, parse_request
|
|
||||||
|
|
||||||
|
|
||||||
def process(envelope: NodeEventEnvelope) -> Result[None]:
|
|
||||||
logging.info(
|
|
||||||
"node event: machine_id: %s event: %s",
|
|
||||||
envelope.machine_id,
|
|
||||||
envelope.event.json(exclude_none=True),
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(envelope.event, NodeStateUpdate):
|
|
||||||
return on_state_update(envelope.machine_id, envelope.event)
|
|
||||||
|
|
||||||
if isinstance(envelope.event, WorkerEvent):
|
|
||||||
return on_worker_event(envelope.machine_id, envelope.event)
|
|
||||||
|
|
||||||
if isinstance(envelope.event, NodeEvent):
|
|
||||||
if envelope.event.state_update:
|
|
||||||
result = on_state_update(envelope.machine_id, envelope.event.state_update)
|
|
||||||
if result is not None:
|
|
||||||
return result
|
|
||||||
|
|
||||||
if envelope.event.worker_event:
|
|
||||||
result = on_worker_event(envelope.machine_id, envelope.event.worker_event)
|
|
||||||
if result is not None:
|
|
||||||
return result
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
raise NotImplementedError("invalid node event: %s" % envelope)
|
|
||||||
|
|
||||||
|
|
||||||
def post(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
envelope = parse_request(NodeEventEnvelope, req)
|
|
||||||
if isinstance(envelope, Error):
|
|
||||||
return not_ok(envelope, context="node event")
|
|
||||||
|
|
||||||
logging.info(
|
|
||||||
"node event: machine_id: %s event: %s",
|
|
||||||
envelope.machine_id,
|
|
||||||
envelope.event.json(exclude_none=True),
|
|
||||||
)
|
|
||||||
|
|
||||||
result = process(envelope)
|
|
||||||
if isinstance(result, Error):
|
|
||||||
logging.error(
|
|
||||||
"unable to process agent event. envelope:%s error:%s", envelope, result
|
|
||||||
)
|
|
||||||
return not_ok(result, context="node event")
|
|
||||||
|
|
||||||
return ok(BoolResult(result=True))
|
|
||||||
|
|
||||||
|
|
||||||
def main(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
methods = {"POST": post}
|
|
||||||
method = methods[req.method]
|
|
||||||
result = call_if_agent(req, method)
|
|
||||||
|
|
||||||
return result
|
|
@ -1,20 +0,0 @@
|
|||||||
{
|
|
||||||
"scriptFile": "__init__.py",
|
|
||||||
"bindings": [
|
|
||||||
{
|
|
||||||
"authLevel": "anonymous",
|
|
||||||
"type": "httpTrigger",
|
|
||||||
"direction": "in",
|
|
||||||
"name": "req",
|
|
||||||
"methods": [
|
|
||||||
"post"
|
|
||||||
],
|
|
||||||
"route": "agents/events"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "http",
|
|
||||||
"direction": "out",
|
|
||||||
"name": "$return"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
@ -1,119 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import datetime
|
|
||||||
import logging
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
import azure.functions as func
|
|
||||||
from onefuzztypes.enums import ErrorCode
|
|
||||||
from onefuzztypes.models import Error
|
|
||||||
from onefuzztypes.requests import AgentRegistrationGet, AgentRegistrationPost
|
|
||||||
from onefuzztypes.responses import AgentRegistration
|
|
||||||
|
|
||||||
from ..onefuzzlib.azure.creds import get_agent_instance_url
|
|
||||||
from ..onefuzzlib.azure.queue import get_queue_sas
|
|
||||||
from ..onefuzzlib.azure.storage import StorageType
|
|
||||||
from ..onefuzzlib.endpoint_authorization import call_if_agent
|
|
||||||
from ..onefuzzlib.request import not_ok, ok, parse_uri
|
|
||||||
from ..onefuzzlib.workers.nodes import Node
|
|
||||||
from ..onefuzzlib.workers.pools import Pool
|
|
||||||
|
|
||||||
|
|
||||||
def create_registration_response(machine_id: UUID, pool: Pool) -> func.HttpResponse:
|
|
||||||
base_address = get_agent_instance_url()
|
|
||||||
events_url = "%s/api/agents/events" % base_address
|
|
||||||
commands_url = "%s/api/agents/commands" % base_address
|
|
||||||
work_queue = get_queue_sas(
|
|
||||||
pool.get_pool_queue(),
|
|
||||||
StorageType.corpus,
|
|
||||||
read=True,
|
|
||||||
update=True,
|
|
||||||
process=True,
|
|
||||||
duration=datetime.timedelta(hours=24),
|
|
||||||
)
|
|
||||||
return ok(
|
|
||||||
AgentRegistration(
|
|
||||||
events_url=events_url,
|
|
||||||
commands_url=commands_url,
|
|
||||||
work_queue=work_queue,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
get_registration = parse_uri(AgentRegistrationGet, req)
|
|
||||||
|
|
||||||
if isinstance(get_registration, Error):
|
|
||||||
return not_ok(get_registration, context="agent registration")
|
|
||||||
|
|
||||||
agent_node = Node.get_by_machine_id(get_registration.machine_id)
|
|
||||||
|
|
||||||
if agent_node is None:
|
|
||||||
return not_ok(
|
|
||||||
Error(
|
|
||||||
code=ErrorCode.INVALID_REQUEST,
|
|
||||||
errors=[
|
|
||||||
"unable to find a registration associated with machine_id '%s'"
|
|
||||||
% get_registration.machine_id
|
|
||||||
],
|
|
||||||
),
|
|
||||||
context="agent registration",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
pool = Pool.get_by_name(agent_node.pool_name)
|
|
||||||
if isinstance(pool, Error):
|
|
||||||
return not_ok(
|
|
||||||
Error(
|
|
||||||
code=ErrorCode.INVALID_REQUEST,
|
|
||||||
errors=[
|
|
||||||
"unable to find a pool associated with the provided machine_id"
|
|
||||||
],
|
|
||||||
),
|
|
||||||
context="agent registration",
|
|
||||||
)
|
|
||||||
|
|
||||||
return create_registration_response(agent_node.machine_id, pool)
|
|
||||||
|
|
||||||
|
|
||||||
def post(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
registration_request = parse_uri(AgentRegistrationPost, req)
|
|
||||||
if isinstance(registration_request, Error):
|
|
||||||
return not_ok(registration_request, context="agent registration")
|
|
||||||
logging.info(
|
|
||||||
"registration request: %s", registration_request.json(exclude_none=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
pool = Pool.get_by_name(registration_request.pool_name)
|
|
||||||
if isinstance(pool, Error):
|
|
||||||
return not_ok(
|
|
||||||
Error(
|
|
||||||
code=ErrorCode.INVALID_REQUEST,
|
|
||||||
errors=["unable to find pool '%s'" % registration_request.pool_name],
|
|
||||||
),
|
|
||||||
context="agent registration",
|
|
||||||
)
|
|
||||||
|
|
||||||
node = Node.get_by_machine_id(registration_request.machine_id)
|
|
||||||
if node:
|
|
||||||
node.delete()
|
|
||||||
|
|
||||||
node = Node.create(
|
|
||||||
pool_id=pool.pool_id,
|
|
||||||
pool_name=pool.name,
|
|
||||||
machine_id=registration_request.machine_id,
|
|
||||||
scaleset_id=registration_request.scaleset_id,
|
|
||||||
version=registration_request.version,
|
|
||||||
)
|
|
||||||
|
|
||||||
return create_registration_response(node.machine_id, pool)
|
|
||||||
|
|
||||||
|
|
||||||
def main(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
methods = {"POST": post, "GET": get}
|
|
||||||
method = methods[req.method]
|
|
||||||
result = call_if_agent(req, method)
|
|
||||||
|
|
||||||
return result
|
|
@ -1,22 +0,0 @@
|
|||||||
{
|
|
||||||
"scriptFile": "__init__.py",
|
|
||||||
"bindings": [
|
|
||||||
{
|
|
||||||
"authLevel": "anonymous",
|
|
||||||
"type": "httpTrigger",
|
|
||||||
"direction": "in",
|
|
||||||
"name": "req",
|
|
||||||
"methods": [
|
|
||||||
"get",
|
|
||||||
"post",
|
|
||||||
"delete"
|
|
||||||
],
|
|
||||||
"route": "agents/registration"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "http",
|
|
||||||
"direction": "out",
|
|
||||||
"name": "$return"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
@ -1,96 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
import azure.functions as func
|
|
||||||
from onefuzztypes.enums import ErrorCode
|
|
||||||
from onefuzztypes.models import Error
|
|
||||||
from onefuzztypes.requests import ContainerCreate, ContainerDelete, ContainerGet
|
|
||||||
from onefuzztypes.responses import BoolResult, ContainerInfo, ContainerInfoBase
|
|
||||||
|
|
||||||
from ..onefuzzlib.azure.containers import (
|
|
||||||
create_container,
|
|
||||||
delete_container,
|
|
||||||
get_container_metadata,
|
|
||||||
get_container_sas_url,
|
|
||||||
get_containers,
|
|
||||||
)
|
|
||||||
from ..onefuzzlib.azure.storage import StorageType
|
|
||||||
from ..onefuzzlib.endpoint_authorization import call_if_user
|
|
||||||
from ..onefuzzlib.request import not_ok, ok, parse_request
|
|
||||||
|
|
||||||
|
|
||||||
def get(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
request: Optional[Union[ContainerGet, Error]] = None
|
|
||||||
if req.get_body():
|
|
||||||
request = parse_request(ContainerGet, req)
|
|
||||||
|
|
||||||
if isinstance(request, Error):
|
|
||||||
return not_ok(request, context="container get")
|
|
||||||
if request is not None:
|
|
||||||
metadata = get_container_metadata(request.name, StorageType.corpus)
|
|
||||||
if metadata is None:
|
|
||||||
return not_ok(
|
|
||||||
Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid container"]),
|
|
||||||
context=request.name,
|
|
||||||
)
|
|
||||||
|
|
||||||
info = ContainerInfo(
|
|
||||||
name=request.name,
|
|
||||||
sas_url=get_container_sas_url(
|
|
||||||
request.name,
|
|
||||||
StorageType.corpus,
|
|
||||||
read=True,
|
|
||||||
write=True,
|
|
||||||
delete=True,
|
|
||||||
list_=True,
|
|
||||||
),
|
|
||||||
metadata=metadata,
|
|
||||||
)
|
|
||||||
return ok(info)
|
|
||||||
|
|
||||||
containers = get_containers(StorageType.corpus)
|
|
||||||
|
|
||||||
container_info = []
|
|
||||||
for name in containers:
|
|
||||||
container_info.append(ContainerInfoBase(name=name, metadata=containers[name]))
|
|
||||||
|
|
||||||
return ok(container_info)
|
|
||||||
|
|
||||||
|
|
||||||
def post(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
request = parse_request(ContainerCreate, req)
|
|
||||||
if isinstance(request, Error):
|
|
||||||
return not_ok(request, context="container create")
|
|
||||||
|
|
||||||
logging.info("container - creating %s", request.name)
|
|
||||||
sas = create_container(request.name, StorageType.corpus, metadata=request.metadata)
|
|
||||||
if sas:
|
|
||||||
return ok(
|
|
||||||
ContainerInfo(name=request.name, sas_url=sas, metadata=request.metadata)
|
|
||||||
)
|
|
||||||
return not_ok(
|
|
||||||
Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid container"]),
|
|
||||||
context=request.name,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def delete(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
request = parse_request(ContainerDelete, req)
|
|
||||||
if isinstance(request, Error):
|
|
||||||
return not_ok(request, context="container delete")
|
|
||||||
|
|
||||||
logging.info("container - deleting %s", request.name)
|
|
||||||
return ok(BoolResult(result=delete_container(request.name, StorageType.corpus)))
|
|
||||||
|
|
||||||
|
|
||||||
def main(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
methods = {"GET": get, "POST": post, "DELETE": delete}
|
|
||||||
method = methods[req.method]
|
|
||||||
result = call_if_user(req, method)
|
|
||||||
|
|
||||||
return result
|
|
@ -1,21 +0,0 @@
|
|||||||
{
|
|
||||||
"scriptFile": "__init__.py",
|
|
||||||
"bindings": [
|
|
||||||
{
|
|
||||||
"authLevel": "anonymous",
|
|
||||||
"type": "httpTrigger",
|
|
||||||
"direction": "in",
|
|
||||||
"name": "req",
|
|
||||||
"methods": [
|
|
||||||
"get",
|
|
||||||
"post",
|
|
||||||
"delete"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "http",
|
|
||||||
"direction": "out",
|
|
||||||
"name": "$return"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
@ -1,55 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
from datetime import timedelta
|
|
||||||
|
|
||||||
import azure.functions as func
|
|
||||||
from onefuzztypes.enums import ErrorCode
|
|
||||||
from onefuzztypes.models import Error, FileEntry
|
|
||||||
|
|
||||||
from ..onefuzzlib.azure.containers import (
|
|
||||||
blob_exists,
|
|
||||||
container_exists,
|
|
||||||
get_file_sas_url,
|
|
||||||
)
|
|
||||||
from ..onefuzzlib.azure.storage import StorageType
|
|
||||||
from ..onefuzzlib.endpoint_authorization import call_if_user
|
|
||||||
from ..onefuzzlib.request import not_ok, parse_uri, redirect
|
|
||||||
|
|
||||||
|
|
||||||
def get(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
request = parse_uri(FileEntry, req)
|
|
||||||
if isinstance(request, Error):
|
|
||||||
return not_ok(request, context="download")
|
|
||||||
|
|
||||||
if not container_exists(request.container, StorageType.corpus):
|
|
||||||
return not_ok(
|
|
||||||
Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid container"]),
|
|
||||||
context=request.container,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not blob_exists(request.container, request.filename, StorageType.corpus):
|
|
||||||
return not_ok(
|
|
||||||
Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid filename"]),
|
|
||||||
context=request.filename,
|
|
||||||
)
|
|
||||||
|
|
||||||
return redirect(
|
|
||||||
get_file_sas_url(
|
|
||||||
request.container,
|
|
||||||
request.filename,
|
|
||||||
StorageType.corpus,
|
|
||||||
read=True,
|
|
||||||
duration=timedelta(minutes=5),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def main(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
methods = {"GET": get}
|
|
||||||
method = methods[req.method]
|
|
||||||
result = call_if_user(req, method)
|
|
||||||
|
|
||||||
return result
|
|
@ -1,19 +0,0 @@
|
|||||||
{
|
|
||||||
"scriptFile": "__init__.py",
|
|
||||||
"bindings": [
|
|
||||||
{
|
|
||||||
"authLevel": "anonymous",
|
|
||||||
"type": "httpTrigger",
|
|
||||||
"direction": "in",
|
|
||||||
"name": "req",
|
|
||||||
"methods": [
|
|
||||||
"get"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "http",
|
|
||||||
"direction": "out",
|
|
||||||
"name": "$return"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
@ -1,34 +0,0 @@
|
|||||||
{
|
|
||||||
"version": "2.0",
|
|
||||||
"extensionBundle": {
|
|
||||||
"id": "Microsoft.Azure.Functions.ExtensionBundle",
|
|
||||||
"version": "[1.*, 2.0.0)"
|
|
||||||
},
|
|
||||||
"logging": {
|
|
||||||
"logLevel": {
|
|
||||||
"default": "Warning",
|
|
||||||
"Host": "Warning",
|
|
||||||
"Function": "Information",
|
|
||||||
"Host.Aggregator": "Warning",
|
|
||||||
"logging:logLevel:Worker": "Warning",
|
|
||||||
"logging:logLevel:Microsoft": "Warning",
|
|
||||||
"AzureFunctionsJobHost:logging:logLevel:Host.Function.Console": "Warning"
|
|
||||||
},
|
|
||||||
"applicationInsights": {
|
|
||||||
"samplingSettings": {
|
|
||||||
"isEnabled": true,
|
|
||||||
"InitialSamplingPercentage": 100,
|
|
||||||
"maxTelemetryItemsPerSecond": 20,
|
|
||||||
"excludedTypes": "Exception"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"extensions": {
|
|
||||||
"queues": {
|
|
||||||
"maxPollingInterval": "00:00:02",
|
|
||||||
"batchSize": 32,
|
|
||||||
"maxDequeueCount": 5
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"functionTimeout": "00:15:00"
|
|
||||||
}
|
|
@ -1,43 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import azure.functions as func
|
|
||||||
from onefuzztypes.responses import Info
|
|
||||||
|
|
||||||
from ..onefuzzlib.azure.creds import (
|
|
||||||
get_base_region,
|
|
||||||
get_base_resource_group,
|
|
||||||
get_insights_appid,
|
|
||||||
get_insights_instrumentation_key,
|
|
||||||
get_instance_id,
|
|
||||||
get_subscription,
|
|
||||||
)
|
|
||||||
from ..onefuzzlib.endpoint_authorization import call_if_user
|
|
||||||
from ..onefuzzlib.request import ok
|
|
||||||
from ..onefuzzlib.versions import versions
|
|
||||||
|
|
||||||
|
|
||||||
def get(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
response = ok(
|
|
||||||
Info(
|
|
||||||
resource_group=get_base_resource_group(),
|
|
||||||
region=get_base_region(),
|
|
||||||
subscription=get_subscription(),
|
|
||||||
versions=versions(),
|
|
||||||
instance_id=get_instance_id(),
|
|
||||||
insights_appid=get_insights_appid(),
|
|
||||||
insights_instrumentation_key=get_insights_instrumentation_key(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
def main(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
methods = {"GET": get}
|
|
||||||
method = methods[req.method]
|
|
||||||
result = call_if_user(req, method)
|
|
||||||
|
|
||||||
return result
|
|
@ -1,19 +0,0 @@
|
|||||||
{
|
|
||||||
"scriptFile": "__init__.py",
|
|
||||||
"bindings": [
|
|
||||||
{
|
|
||||||
"authLevel": "anonymous",
|
|
||||||
"type": "httpTrigger",
|
|
||||||
"direction": "in",
|
|
||||||
"name": "req",
|
|
||||||
"methods": [
|
|
||||||
"get"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "http",
|
|
||||||
"direction": "out",
|
|
||||||
"name": "$return"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
@ -1,77 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
import azure.functions as func
|
|
||||||
from onefuzztypes.enums import ErrorCode
|
|
||||||
from onefuzztypes.models import Error
|
|
||||||
from onefuzztypes.requests import InstanceConfigUpdate
|
|
||||||
|
|
||||||
from ..onefuzzlib.azure.nsg import is_onefuzz_nsg, list_nsgs, set_allowed
|
|
||||||
from ..onefuzzlib.config import InstanceConfig
|
|
||||||
from ..onefuzzlib.endpoint_authorization import call_if_user, can_modify_config
|
|
||||||
from ..onefuzzlib.request import not_ok, ok, parse_request
|
|
||||||
|
|
||||||
|
|
||||||
def get(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
return ok(InstanceConfig.fetch())
|
|
||||||
|
|
||||||
|
|
||||||
def post(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
request = parse_request(InstanceConfigUpdate, req)
|
|
||||||
if isinstance(request, Error):
|
|
||||||
return not_ok(request, context="instance_config_update")
|
|
||||||
|
|
||||||
config = InstanceConfig.fetch()
|
|
||||||
|
|
||||||
if not can_modify_config(req, config):
|
|
||||||
return not_ok(
|
|
||||||
Error(code=ErrorCode.INVALID_PERMISSION, errors=["unauthorized"]),
|
|
||||||
context="instance_config_update",
|
|
||||||
)
|
|
||||||
|
|
||||||
update_nsg = False
|
|
||||||
if request.config.proxy_nsg_config and config.proxy_nsg_config:
|
|
||||||
request_config = request.config.proxy_nsg_config
|
|
||||||
current_config = config.proxy_nsg_config
|
|
||||||
if set(request_config.allowed_service_tags) != set(
|
|
||||||
current_config.allowed_service_tags
|
|
||||||
) or set(request_config.allowed_ips) != set(current_config.allowed_ips):
|
|
||||||
update_nsg = True
|
|
||||||
|
|
||||||
config.update(request.config)
|
|
||||||
config.save()
|
|
||||||
|
|
||||||
# Update All NSGs
|
|
||||||
if update_nsg:
|
|
||||||
nsgs = list_nsgs()
|
|
||||||
for nsg in nsgs:
|
|
||||||
logging.info(
|
|
||||||
"Checking if nsg: %s (%s) owned by OneFuzz" % (nsg.location, nsg.name)
|
|
||||||
)
|
|
||||||
if is_onefuzz_nsg(nsg.location, nsg.name):
|
|
||||||
result = set_allowed(nsg.location, request.config.proxy_nsg_config)
|
|
||||||
if isinstance(result, Error):
|
|
||||||
return not_ok(
|
|
||||||
Error(
|
|
||||||
code=ErrorCode.UNABLE_TO_CREATE,
|
|
||||||
errors=[
|
|
||||||
"Unable to update nsg %s due to %s"
|
|
||||||
% (nsg.location, result)
|
|
||||||
],
|
|
||||||
),
|
|
||||||
context="instance_config_update",
|
|
||||||
)
|
|
||||||
|
|
||||||
return ok(config)
|
|
||||||
|
|
||||||
|
|
||||||
def main(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
methods = {"GET": get, "POST": post}
|
|
||||||
method = methods[req.method]
|
|
||||||
result = call_if_user(req, method)
|
|
||||||
|
|
||||||
return result
|
|
@ -1,20 +0,0 @@
|
|||||||
{
|
|
||||||
"scriptFile": "__init__.py",
|
|
||||||
"bindings": [
|
|
||||||
{
|
|
||||||
"authLevel": "anonymous",
|
|
||||||
"type": "httpTrigger",
|
|
||||||
"direction": "in",
|
|
||||||
"name": "req",
|
|
||||||
"methods": [
|
|
||||||
"get",
|
|
||||||
"post"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "http",
|
|
||||||
"direction": "out",
|
|
||||||
"name": "$return"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
@ -1,42 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import azure.functions as func
|
|
||||||
from onefuzztypes.job_templates import JobTemplateRequest
|
|
||||||
from onefuzztypes.models import Error
|
|
||||||
|
|
||||||
from ..onefuzzlib.endpoint_authorization import call_if_user
|
|
||||||
from ..onefuzzlib.job_templates.templates import JobTemplateIndex
|
|
||||||
from ..onefuzzlib.request import not_ok, ok, parse_request
|
|
||||||
from ..onefuzzlib.user_credentials import parse_jwt_token
|
|
||||||
|
|
||||||
|
|
||||||
def get(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
configs = JobTemplateIndex.get_configs()
|
|
||||||
return ok(configs)
|
|
||||||
|
|
||||||
|
|
||||||
def post(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
request = parse_request(JobTemplateRequest, req)
|
|
||||||
if isinstance(request, Error):
|
|
||||||
return not_ok(request, context="JobTemplateRequest")
|
|
||||||
|
|
||||||
user_info = parse_jwt_token(req)
|
|
||||||
if isinstance(user_info, Error):
|
|
||||||
return not_ok(user_info, context="JobTemplateRequest")
|
|
||||||
|
|
||||||
job = JobTemplateIndex.execute(request, user_info)
|
|
||||||
if isinstance(job, Error):
|
|
||||||
return not_ok(job, context="JobTemplateRequest")
|
|
||||||
|
|
||||||
return ok(job)
|
|
||||||
|
|
||||||
|
|
||||||
def main(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
methods = {"GET": get, "POST": post}
|
|
||||||
method = methods[req.method]
|
|
||||||
result = call_if_user(req, method)
|
|
||||||
|
|
||||||
return result
|
|
@ -1,20 +0,0 @@
|
|||||||
{
|
|
||||||
"scriptFile": "__init__.py",
|
|
||||||
"bindings": [
|
|
||||||
{
|
|
||||||
"authLevel": "anonymous",
|
|
||||||
"type": "httpTrigger",
|
|
||||||
"direction": "in",
|
|
||||||
"name": "req",
|
|
||||||
"methods": [
|
|
||||||
"get",
|
|
||||||
"post"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "http",
|
|
||||||
"direction": "out",
|
|
||||||
"name": "$return"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
@ -1,69 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import azure.functions as func
|
|
||||||
from onefuzztypes.enums import ErrorCode
|
|
||||||
from onefuzztypes.job_templates import (
|
|
||||||
JobTemplateDelete,
|
|
||||||
JobTemplateGet,
|
|
||||||
JobTemplateUpload,
|
|
||||||
)
|
|
||||||
from onefuzztypes.models import Error
|
|
||||||
from onefuzztypes.responses import BoolResult
|
|
||||||
|
|
||||||
from ..onefuzzlib.endpoint_authorization import call_if_user
|
|
||||||
from ..onefuzzlib.job_templates.templates import JobTemplateIndex
|
|
||||||
from ..onefuzzlib.request import not_ok, ok, parse_request
|
|
||||||
|
|
||||||
|
|
||||||
def get(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
request = parse_request(JobTemplateGet, req)
|
|
||||||
if isinstance(request, Error):
|
|
||||||
return not_ok(request, context="JobTemplateGet")
|
|
||||||
|
|
||||||
if request.name:
|
|
||||||
entry = JobTemplateIndex.get_base_entry(request.name)
|
|
||||||
if entry is None:
|
|
||||||
return not_ok(
|
|
||||||
Error(code=ErrorCode.INVALID_REQUEST, errors=["no such job template"]),
|
|
||||||
context="JobTemplateGet",
|
|
||||||
)
|
|
||||||
return ok(entry.template)
|
|
||||||
|
|
||||||
templates = JobTemplateIndex.get_index()
|
|
||||||
return ok(templates)
|
|
||||||
|
|
||||||
|
|
||||||
def post(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
request = parse_request(JobTemplateUpload, req)
|
|
||||||
if isinstance(request, Error):
|
|
||||||
return not_ok(request, context="JobTemplateUpload")
|
|
||||||
|
|
||||||
entry = JobTemplateIndex(name=request.name, template=request.template)
|
|
||||||
result = entry.save()
|
|
||||||
if isinstance(result, Error):
|
|
||||||
return not_ok(result, context="JobTemplateUpload")
|
|
||||||
|
|
||||||
return ok(BoolResult(result=True))
|
|
||||||
|
|
||||||
|
|
||||||
def delete(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
request = parse_request(JobTemplateDelete, req)
|
|
||||||
if isinstance(request, Error):
|
|
||||||
return not_ok(request, context="JobTemplateDelete")
|
|
||||||
|
|
||||||
entry = JobTemplateIndex.get(request.name)
|
|
||||||
if entry is not None:
|
|
||||||
entry.delete()
|
|
||||||
|
|
||||||
return ok(BoolResult(result=entry is not None))
|
|
||||||
|
|
||||||
|
|
||||||
def main(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
methods = {"GET": get, "POST": post, "DELETE": delete}
|
|
||||||
method = methods[req.method]
|
|
||||||
result = call_if_user(req, method)
|
|
||||||
|
|
||||||
return result
|
|
@ -1,22 +0,0 @@
|
|||||||
{
|
|
||||||
"scriptFile": "__init__.py",
|
|
||||||
"bindings": [
|
|
||||||
{
|
|
||||||
"authLevel": "anonymous",
|
|
||||||
"type": "httpTrigger",
|
|
||||||
"direction": "in",
|
|
||||||
"name": "req",
|
|
||||||
"methods": [
|
|
||||||
"get",
|
|
||||||
"post",
|
|
||||||
"delete"
|
|
||||||
],
|
|
||||||
"route": "job_templates/manage"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "http",
|
|
||||||
"direction": "out",
|
|
||||||
"name": "$return"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
@ -1,109 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
import azure.functions as func
|
|
||||||
from onefuzztypes.enums import ContainerType, ErrorCode, JobState
|
|
||||||
from onefuzztypes.models import Error, JobConfig, JobTaskInfo
|
|
||||||
from onefuzztypes.primitives import Container
|
|
||||||
from onefuzztypes.requests import JobGet, JobSearch
|
|
||||||
|
|
||||||
from ..onefuzzlib.azure.containers import create_container
|
|
||||||
from ..onefuzzlib.azure.storage import StorageType
|
|
||||||
from ..onefuzzlib.endpoint_authorization import call_if_user
|
|
||||||
from ..onefuzzlib.jobs import Job
|
|
||||||
from ..onefuzzlib.request import not_ok, ok, parse_request
|
|
||||||
from ..onefuzzlib.tasks.main import Task
|
|
||||||
from ..onefuzzlib.user_credentials import parse_jwt_token
|
|
||||||
|
|
||||||
|
|
||||||
def get(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
request = parse_request(JobSearch, req)
|
|
||||||
if isinstance(request, Error):
|
|
||||||
return not_ok(request, context="jobs")
|
|
||||||
|
|
||||||
if request.job_id:
|
|
||||||
job = Job.get(request.job_id)
|
|
||||||
if not job:
|
|
||||||
return not_ok(
|
|
||||||
Error(code=ErrorCode.INVALID_JOB, errors=["no such job"]),
|
|
||||||
context=request.job_id,
|
|
||||||
)
|
|
||||||
task_info = []
|
|
||||||
for task in Task.search_states(job_id=request.job_id):
|
|
||||||
task_info.append(
|
|
||||||
JobTaskInfo(
|
|
||||||
task_id=task.task_id, type=task.config.task.type, state=task.state
|
|
||||||
)
|
|
||||||
)
|
|
||||||
job.task_info = task_info
|
|
||||||
return ok(job)
|
|
||||||
|
|
||||||
jobs = Job.search_states(states=request.state)
|
|
||||||
return ok(jobs)
|
|
||||||
|
|
||||||
|
|
||||||
def post(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
request = parse_request(JobConfig, req)
|
|
||||||
if isinstance(request, Error):
|
|
||||||
return not_ok(request, context="jobs create")
|
|
||||||
|
|
||||||
user_info = parse_jwt_token(req)
|
|
||||||
if isinstance(user_info, Error):
|
|
||||||
return not_ok(user_info, context="jobs create")
|
|
||||||
|
|
||||||
job = Job(job_id=uuid4(), config=request, user_info=user_info)
|
|
||||||
# create the job logs container
|
|
||||||
log_container_sas = create_container(
|
|
||||||
Container(f"logs-{job.job_id}"),
|
|
||||||
StorageType.corpus,
|
|
||||||
metadata={"container_type": ContainerType.logs.name},
|
|
||||||
)
|
|
||||||
if not log_container_sas:
|
|
||||||
return not_ok(
|
|
||||||
Error(
|
|
||||||
code=ErrorCode.UNABLE_TO_CREATE_CONTAINER,
|
|
||||||
errors=["unable to create logs container"],
|
|
||||||
),
|
|
||||||
context="logs",
|
|
||||||
)
|
|
||||||
sep_index = log_container_sas.find("?")
|
|
||||||
if sep_index > 0:
|
|
||||||
log_container = log_container_sas[:sep_index]
|
|
||||||
else:
|
|
||||||
log_container = log_container_sas
|
|
||||||
|
|
||||||
job.config.logs = log_container
|
|
||||||
job.save()
|
|
||||||
|
|
||||||
return ok(job)
|
|
||||||
|
|
||||||
|
|
||||||
def delete(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
request = parse_request(JobGet, req)
|
|
||||||
if isinstance(request, Error):
|
|
||||||
return not_ok(request, context="jobs delete")
|
|
||||||
|
|
||||||
job = Job.get(request.job_id)
|
|
||||||
if not job:
|
|
||||||
return not_ok(
|
|
||||||
Error(code=ErrorCode.INVALID_JOB, errors=["no such job"]),
|
|
||||||
context=request.job_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if job.state not in [JobState.stopped, JobState.stopping]:
|
|
||||||
job.state = JobState.stopping
|
|
||||||
job.save()
|
|
||||||
|
|
||||||
return ok(job)
|
|
||||||
|
|
||||||
|
|
||||||
def main(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
methods = {"GET": get, "POST": post, "DELETE": delete}
|
|
||||||
method = methods[req.method]
|
|
||||||
result = call_if_user(req, method)
|
|
||||||
|
|
||||||
return result
|
|
@ -1,21 +0,0 @@
|
|||||||
{
|
|
||||||
"scriptFile": "__init__.py",
|
|
||||||
"bindings": [
|
|
||||||
{
|
|
||||||
"authLevel": "anonymous",
|
|
||||||
"type": "httpTrigger",
|
|
||||||
"direction": "in",
|
|
||||||
"name": "req",
|
|
||||||
"methods": [
|
|
||||||
"get",
|
|
||||||
"post",
|
|
||||||
"delete"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "http",
|
|
||||||
"direction": "out",
|
|
||||||
"name": "$return"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
@ -1,34 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import azure.functions as func
|
|
||||||
|
|
||||||
from ..onefuzzlib.endpoint_authorization import call_if_user
|
|
||||||
|
|
||||||
# This endpoint handles the signalr negotation
|
|
||||||
# As we do not differentiate from clients at this time, we pass the Functions runtime
|
|
||||||
# provided connection straight to the client
|
|
||||||
#
|
|
||||||
# For more info:
|
|
||||||
# https://docs.microsoft.com/en-us/azure/azure-signalr/signalr-concept-internals
|
|
||||||
|
|
||||||
|
|
||||||
def main(req: func.HttpRequest, connectionInfoJson: str) -> func.HttpResponse:
|
|
||||||
# NOTE: this is a sub-method because the call_if* do not support callbacks with
|
|
||||||
# additional arguments at this time. Once call_if* supports additional arguments,
|
|
||||||
# this should be made a generic function
|
|
||||||
def post(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
return func.HttpResponse(
|
|
||||||
connectionInfoJson,
|
|
||||||
status_code=200,
|
|
||||||
headers={"Content-type": "application/json"},
|
|
||||||
)
|
|
||||||
|
|
||||||
methods = {"POST": post}
|
|
||||||
method = methods[req.method]
|
|
||||||
|
|
||||||
result = call_if_user(req, method)
|
|
||||||
|
|
||||||
return result
|
|
@ -1,25 +0,0 @@
|
|||||||
{
|
|
||||||
"scriptFile": "__init__.py",
|
|
||||||
"bindings": [
|
|
||||||
{
|
|
||||||
"authLevel": "anonymous",
|
|
||||||
"type": "httpTrigger",
|
|
||||||
"direction": "in",
|
|
||||||
"name": "req",
|
|
||||||
"methods": [
|
|
||||||
"post"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "http",
|
|
||||||
"direction": "out",
|
|
||||||
"name": "$return"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "signalRConnectionInfo",
|
|
||||||
"direction": "in",
|
|
||||||
"name": "connectionInfoJson",
|
|
||||||
"hubName": "dashboard"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
@ -1,122 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import azure.functions as func
|
|
||||||
from onefuzztypes.enums import ErrorCode
|
|
||||||
from onefuzztypes.models import Error
|
|
||||||
from onefuzztypes.requests import NodeGet, NodeSearch, NodeUpdate
|
|
||||||
from onefuzztypes.responses import BoolResult
|
|
||||||
|
|
||||||
from ..onefuzzlib.endpoint_authorization import call_if_user, check_require_admins
|
|
||||||
from ..onefuzzlib.request import not_ok, ok, parse_request
|
|
||||||
from ..onefuzzlib.workers.nodes import Node, NodeMessage, NodeTasks
|
|
||||||
|
|
||||||
|
|
||||||
def get(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
request = parse_request(NodeSearch, req)
|
|
||||||
if isinstance(request, Error):
|
|
||||||
return not_ok(request, context="pool get")
|
|
||||||
|
|
||||||
if request.machine_id:
|
|
||||||
node = Node.get_by_machine_id(request.machine_id)
|
|
||||||
if not node:
|
|
||||||
return not_ok(
|
|
||||||
Error(code=ErrorCode.UNABLE_TO_FIND, errors=["unable to find node"]),
|
|
||||||
context=request.machine_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(node, Error):
|
|
||||||
return not_ok(node, context=request.machine_id)
|
|
||||||
|
|
||||||
node.tasks = [n for n in NodeTasks.get_by_machine_id(request.machine_id)]
|
|
||||||
node.messages = [
|
|
||||||
x.message for x in NodeMessage.get_messages(request.machine_id)
|
|
||||||
]
|
|
||||||
|
|
||||||
return ok(node)
|
|
||||||
|
|
||||||
nodes = Node.search_states(
|
|
||||||
states=request.state,
|
|
||||||
pool_name=request.pool_name,
|
|
||||||
scaleset_id=request.scaleset_id,
|
|
||||||
)
|
|
||||||
return ok(nodes)
|
|
||||||
|
|
||||||
|
|
||||||
def post(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
request = parse_request(NodeUpdate, req)
|
|
||||||
if isinstance(request, Error):
|
|
||||||
return not_ok(request, context="NodeUpdate")
|
|
||||||
|
|
||||||
answer = check_require_admins(req)
|
|
||||||
if isinstance(answer, Error):
|
|
||||||
return not_ok(answer, context="NodeUpdate")
|
|
||||||
|
|
||||||
node = Node.get_by_machine_id(request.machine_id)
|
|
||||||
if not node:
|
|
||||||
return not_ok(
|
|
||||||
Error(code=ErrorCode.UNABLE_TO_FIND, errors=["unable to find node"]),
|
|
||||||
context=request.machine_id,
|
|
||||||
)
|
|
||||||
if request.debug_keep_node is not None:
|
|
||||||
node.debug_keep_node = request.debug_keep_node
|
|
||||||
|
|
||||||
node.save()
|
|
||||||
return ok(BoolResult(result=True))
|
|
||||||
|
|
||||||
|
|
||||||
def delete(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
request = parse_request(NodeGet, req)
|
|
||||||
if isinstance(request, Error):
|
|
||||||
return not_ok(request, context="NodeDelete")
|
|
||||||
|
|
||||||
answer = check_require_admins(req)
|
|
||||||
if isinstance(answer, Error):
|
|
||||||
return not_ok(answer, context="NodeDelete")
|
|
||||||
|
|
||||||
node = Node.get_by_machine_id(request.machine_id)
|
|
||||||
if not node:
|
|
||||||
return not_ok(
|
|
||||||
Error(code=ErrorCode.UNABLE_TO_FIND, errors=["unable to find node"]),
|
|
||||||
context=request.machine_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
node.set_halt()
|
|
||||||
if node.debug_keep_node:
|
|
||||||
node.debug_keep_node = False
|
|
||||||
node.save()
|
|
||||||
|
|
||||||
return ok(BoolResult(result=True))
|
|
||||||
|
|
||||||
|
|
||||||
def patch(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
request = parse_request(NodeGet, req)
|
|
||||||
if isinstance(request, Error):
|
|
||||||
return not_ok(request, context="NodeReimage")
|
|
||||||
|
|
||||||
answer = check_require_admins(req)
|
|
||||||
if isinstance(answer, Error):
|
|
||||||
return not_ok(answer, context="NodeReimage")
|
|
||||||
|
|
||||||
node = Node.get_by_machine_id(request.machine_id)
|
|
||||||
if not node:
|
|
||||||
return not_ok(
|
|
||||||
Error(code=ErrorCode.UNABLE_TO_FIND, errors=["unable to find node"]),
|
|
||||||
context=request.machine_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
node.stop(done=True)
|
|
||||||
if node.debug_keep_node:
|
|
||||||
node.debug_keep_node = False
|
|
||||||
node.save()
|
|
||||||
return ok(BoolResult(result=True))
|
|
||||||
|
|
||||||
|
|
||||||
def main(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
methods = {"GET": get, "PATCH": patch, "DELETE": delete, "POST": post}
|
|
||||||
method = methods[req.method]
|
|
||||||
result = call_if_user(req, method)
|
|
||||||
|
|
||||||
return result
|
|
@ -1,22 +0,0 @@
|
|||||||
{
|
|
||||||
"scriptFile": "__init__.py",
|
|
||||||
"bindings": [
|
|
||||||
{
|
|
||||||
"authLevel": "anonymous",
|
|
||||||
"type": "httpTrigger",
|
|
||||||
"direction": "in",
|
|
||||||
"name": "req",
|
|
||||||
"methods": [
|
|
||||||
"get",
|
|
||||||
"patch",
|
|
||||||
"delete",
|
|
||||||
"post"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "http",
|
|
||||||
"direction": "out",
|
|
||||||
"name": "$return"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
@ -1,40 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import azure.functions as func
|
|
||||||
from onefuzztypes.enums import ErrorCode
|
|
||||||
from onefuzztypes.models import Error
|
|
||||||
from onefuzztypes.requests import NodeAddSshKey
|
|
||||||
from onefuzztypes.responses import BoolResult
|
|
||||||
|
|
||||||
from ..onefuzzlib.endpoint_authorization import call_if_user
|
|
||||||
from ..onefuzzlib.request import not_ok, ok, parse_request
|
|
||||||
from ..onefuzzlib.workers.nodes import Node
|
|
||||||
|
|
||||||
|
|
||||||
def post(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
request = parse_request(NodeAddSshKey, req)
|
|
||||||
if isinstance(request, Error):
|
|
||||||
return not_ok(request, context="NodeAddSshKey")
|
|
||||||
|
|
||||||
node = Node.get_by_machine_id(request.machine_id)
|
|
||||||
if not node:
|
|
||||||
return not_ok(
|
|
||||||
Error(code=ErrorCode.UNABLE_TO_FIND, errors=["unable to find node"]),
|
|
||||||
context=request.machine_id,
|
|
||||||
)
|
|
||||||
result = node.add_ssh_public_key(public_key=request.public_key)
|
|
||||||
if isinstance(result, Error):
|
|
||||||
return not_ok(result, context="NodeAddSshKey")
|
|
||||||
|
|
||||||
return ok(BoolResult(result=True))
|
|
||||||
|
|
||||||
|
|
||||||
def main(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
methods = {"POST": post}
|
|
||||||
method = methods[req.method]
|
|
||||||
result = call_if_user(req, method)
|
|
||||||
|
|
||||||
return result
|
|
@ -1,20 +0,0 @@
|
|||||||
{
|
|
||||||
"scriptFile": "__init__.py",
|
|
||||||
"bindings": [
|
|
||||||
{
|
|
||||||
"authLevel": "anonymous",
|
|
||||||
"type": "httpTrigger",
|
|
||||||
"direction": "in",
|
|
||||||
"name": "req",
|
|
||||||
"methods": [
|
|
||||||
"post"
|
|
||||||
],
|
|
||||||
"route": "node/add_ssh_key"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "http",
|
|
||||||
"direction": "out",
|
|
||||||
"name": "$return"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
@ -1,70 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
import azure.functions as func
|
|
||||||
from onefuzztypes.models import Error
|
|
||||||
from onefuzztypes.requests import (
|
|
||||||
NotificationCreate,
|
|
||||||
NotificationGet,
|
|
||||||
NotificationSearch,
|
|
||||||
)
|
|
||||||
|
|
||||||
from ..onefuzzlib.endpoint_authorization import call_if_user
|
|
||||||
from ..onefuzzlib.notifications.main import Notification
|
|
||||||
from ..onefuzzlib.request import not_ok, ok, parse_request
|
|
||||||
|
|
||||||
|
|
||||||
def get(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
logging.info("notification search")
|
|
||||||
request = parse_request(NotificationSearch, req)
|
|
||||||
if isinstance(request, Error):
|
|
||||||
return not_ok(request, context="notification search")
|
|
||||||
|
|
||||||
if request.container:
|
|
||||||
entries = Notification.search(query={"container": request.container})
|
|
||||||
else:
|
|
||||||
entries = Notification.search()
|
|
||||||
|
|
||||||
return ok(entries)
|
|
||||||
|
|
||||||
|
|
||||||
def post(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
logging.info("adding notification hook")
|
|
||||||
request = parse_request(NotificationCreate, req)
|
|
||||||
if isinstance(request, Error):
|
|
||||||
return not_ok(request, context="notification create")
|
|
||||||
|
|
||||||
entry = Notification.create(
|
|
||||||
container=request.container,
|
|
||||||
config=request.config,
|
|
||||||
replace_existing=request.replace_existing,
|
|
||||||
)
|
|
||||||
if isinstance(entry, Error):
|
|
||||||
return not_ok(entry, context="notification create")
|
|
||||||
|
|
||||||
return ok(entry)
|
|
||||||
|
|
||||||
|
|
||||||
def delete(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
request = parse_request(NotificationGet, req)
|
|
||||||
if isinstance(request, Error):
|
|
||||||
return not_ok(request, context="notification delete")
|
|
||||||
|
|
||||||
entry = Notification.get_by_id(request.notification_id)
|
|
||||||
if isinstance(entry, Error):
|
|
||||||
return not_ok(entry, context="notification delete")
|
|
||||||
|
|
||||||
entry.delete()
|
|
||||||
return ok(entry)
|
|
||||||
|
|
||||||
|
|
||||||
def main(req: func.HttpRequest) -> func.HttpResponse:
|
|
||||||
methods = {"GET": get, "POST": post, "DELETE": delete}
|
|
||||||
method = methods[req.method]
|
|
||||||
result = call_if_user(req, method)
|
|
||||||
|
|
||||||
return result
|
|
@ -1,21 +0,0 @@
|
|||||||
{
|
|
||||||
"scriptFile": "__init__.py",
|
|
||||||
"bindings": [
|
|
||||||
{
|
|
||||||
"authLevel": "anonymous",
|
|
||||||
"type": "httpTrigger",
|
|
||||||
"direction": "in",
|
|
||||||
"name": "req",
|
|
||||||
"methods": [
|
|
||||||
"get",
|
|
||||||
"post",
|
|
||||||
"delete"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "http",
|
|
||||||
"direction": "out",
|
|
||||||
"name": "$return"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
@ -1,5 +0,0 @@
|
|||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
# pylint: disable=W0612,C0111
|
|
||||||
__version__ = "0.0.0"
|
|
@ -1,271 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import datetime
|
|
||||||
import logging
|
|
||||||
from typing import Optional, cast
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from onefuzztypes.enums import (
|
|
||||||
ErrorCode,
|
|
||||||
NodeState,
|
|
||||||
NodeTaskState,
|
|
||||||
TaskDebugFlag,
|
|
||||||
TaskState,
|
|
||||||
)
|
|
||||||
from onefuzztypes.models import (
|
|
||||||
Error,
|
|
||||||
NodeDoneEventData,
|
|
||||||
NodeSettingUpEventData,
|
|
||||||
NodeStateUpdate,
|
|
||||||
Result,
|
|
||||||
WorkerDoneEvent,
|
|
||||||
WorkerEvent,
|
|
||||||
WorkerRunningEvent,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .task_event import TaskEvent
|
|
||||||
from .tasks.main import Task
|
|
||||||
from .workers.nodes import Node, NodeTasks
|
|
||||||
|
|
||||||
MAX_OUTPUT_SIZE = 4096
|
|
||||||
|
|
||||||
|
|
||||||
def get_node(machine_id: UUID) -> Result[Node]:
|
|
||||||
node = Node.get_by_machine_id(machine_id)
|
|
||||||
if not node:
|
|
||||||
return Error(code=ErrorCode.INVALID_NODE, errors=["unable to find node"])
|
|
||||||
return node
|
|
||||||
|
|
||||||
|
|
||||||
def on_state_update(
|
|
||||||
machine_id: UUID,
|
|
||||||
state_update: NodeStateUpdate,
|
|
||||||
) -> Result[None]:
|
|
||||||
state = state_update.state
|
|
||||||
node = get_node(machine_id)
|
|
||||||
if isinstance(node, Error):
|
|
||||||
if state == NodeState.done:
|
|
||||||
logging.warning(
|
|
||||||
"unable to process state update event. machine_id:"
|
|
||||||
f"{machine_id} state event:{state_update} error:{node}"
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
return node
|
|
||||||
|
|
||||||
if state == NodeState.free:
|
|
||||||
if node.reimage_requested or node.delete_requested:
|
|
||||||
logging.info("stopping free node with reset flags: %s", node.machine_id)
|
|
||||||
node.stop()
|
|
||||||
return None
|
|
||||||
|
|
||||||
if node.could_shrink_scaleset():
|
|
||||||
logging.info("stopping free node to resize scaleset: %s", node.machine_id)
|
|
||||||
node.set_halt()
|
|
||||||
return None
|
|
||||||
|
|
||||||
if state == NodeState.init:
|
|
||||||
if node.delete_requested:
|
|
||||||
logging.info("stopping node (init and delete_requested): %s", machine_id)
|
|
||||||
node.stop()
|
|
||||||
return None
|
|
||||||
|
|
||||||
# not checking reimage_requested, as nodes only send 'init' state once. If
|
|
||||||
# they send 'init' with reimage_requested, it's because the node was reimaged
|
|
||||||
# successfully.
|
|
||||||
node.reimage_requested = False
|
|
||||||
node.initialized_at = datetime.datetime.now(datetime.timezone.utc)
|
|
||||||
node.set_state(state)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
logging.info("node state update: %s from:%s to:%s", machine_id, node.state, state)
|
|
||||||
node.set_state(state)
|
|
||||||
|
|
||||||
if state == NodeState.free:
|
|
||||||
logging.info("node now available for work: %s", machine_id)
|
|
||||||
elif state == NodeState.setting_up:
|
|
||||||
# Model-validated.
|
|
||||||
#
|
|
||||||
# This field will be required in the future.
|
|
||||||
# For now, it is optional for back compat.
|
|
||||||
setting_up_data = cast(
|
|
||||||
Optional[NodeSettingUpEventData],
|
|
||||||
state_update.data,
|
|
||||||
)
|
|
||||||
|
|
||||||
if setting_up_data:
|
|
||||||
if not setting_up_data.tasks:
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.INVALID_REQUEST,
|
|
||||||
errors=["setup without tasks. machine_id: %s", str(machine_id)],
|
|
||||||
)
|
|
||||||
|
|
||||||
for task_id in setting_up_data.tasks:
|
|
||||||
task = Task.get_by_task_id(task_id)
|
|
||||||
if isinstance(task, Error):
|
|
||||||
return task
|
|
||||||
|
|
||||||
logging.info(
|
|
||||||
"node starting task. machine_id: %s job_id: %s task_id: %s",
|
|
||||||
machine_id,
|
|
||||||
task.job_id,
|
|
||||||
task.task_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# The task state may be `running` if it has `vm_count` > 1, and
|
|
||||||
# another node is concurrently executing the task. If so, leave
|
|
||||||
# the state as-is, to represent the max progress made.
|
|
||||||
#
|
|
||||||
# Other states we would want to preserve are excluded by the
|
|
||||||
# outermost conditional check.
|
|
||||||
if task.state not in [TaskState.running, TaskState.setting_up]:
|
|
||||||
task.set_state(TaskState.setting_up)
|
|
||||||
|
|
||||||
# Note: we set the node task state to `setting_up`, even though
|
|
||||||
# the task itself may be `running`.
|
|
||||||
node_task = NodeTasks(
|
|
||||||
machine_id=machine_id,
|
|
||||||
task_id=task_id,
|
|
||||||
state=NodeTaskState.setting_up,
|
|
||||||
)
|
|
||||||
node_task.save()
|
|
||||||
|
|
||||||
elif state == NodeState.done:
|
|
||||||
# Model-validated.
|
|
||||||
#
|
|
||||||
# This field will be required in the future.
|
|
||||||
# For now, it is optional for back compat.
|
|
||||||
done_data = cast(Optional[NodeDoneEventData], state_update.data)
|
|
||||||
error = None
|
|
||||||
if done_data:
|
|
||||||
if done_data.error:
|
|
||||||
error_text = done_data.json(exclude_none=True)
|
|
||||||
error = Error(
|
|
||||||
code=ErrorCode.TASK_FAILED,
|
|
||||||
errors=[error_text],
|
|
||||||
)
|
|
||||||
logging.error(
|
|
||||||
"node 'done' with error: machine_id:%s, data:%s",
|
|
||||||
machine_id,
|
|
||||||
error_text,
|
|
||||||
)
|
|
||||||
|
|
||||||
# if tasks are running on the node when it reports as Done
|
|
||||||
# those are stopped early
|
|
||||||
node.mark_tasks_stopped_early(error=error)
|
|
||||||
node.to_reimage(done=True)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def on_worker_event_running(
|
|
||||||
machine_id: UUID, event: WorkerRunningEvent
|
|
||||||
) -> Result[None]:
|
|
||||||
task = Task.get_by_task_id(event.task_id)
|
|
||||||
if isinstance(task, Error):
|
|
||||||
return task
|
|
||||||
|
|
||||||
node = get_node(machine_id)
|
|
||||||
if isinstance(node, Error):
|
|
||||||
return node
|
|
||||||
|
|
||||||
if node.state not in NodeState.ready_for_reset():
|
|
||||||
node.set_state(NodeState.busy)
|
|
||||||
|
|
||||||
node_task = NodeTasks(
|
|
||||||
machine_id=machine_id, task_id=event.task_id, state=NodeTaskState.running
|
|
||||||
)
|
|
||||||
node_task.save()
|
|
||||||
|
|
||||||
if task.state in TaskState.shutting_down():
|
|
||||||
logging.info(
|
|
||||||
"ignoring task start from node. "
|
|
||||||
"machine_id:%s job_id:%s task_id:%s (state: %s)",
|
|
||||||
machine_id,
|
|
||||||
task.job_id,
|
|
||||||
task.task_id,
|
|
||||||
task.state,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
logging.info(
|
|
||||||
"task started on node. machine_id:%s job_id%s task_id:%s",
|
|
||||||
machine_id,
|
|
||||||
task.job_id,
|
|
||||||
task.task_id,
|
|
||||||
)
|
|
||||||
task.set_state(TaskState.running)
|
|
||||||
|
|
||||||
task_event = TaskEvent(
|
|
||||||
task_id=task.task_id,
|
|
||||||
machine_id=machine_id,
|
|
||||||
event_data=WorkerEvent(running=event),
|
|
||||||
)
|
|
||||||
task_event.save()
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def on_worker_event_done(machine_id: UUID, event: WorkerDoneEvent) -> Result[None]:
|
|
||||||
task = Task.get_by_task_id(event.task_id)
|
|
||||||
if isinstance(task, Error):
|
|
||||||
return task
|
|
||||||
|
|
||||||
node = get_node(machine_id)
|
|
||||||
if isinstance(node, Error):
|
|
||||||
return node
|
|
||||||
|
|
||||||
if event.exit_status.success:
|
|
||||||
logging.info(
|
|
||||||
"task done. %s:%s status:%s", task.job_id, task.task_id, event.exit_status
|
|
||||||
)
|
|
||||||
task.mark_stopping()
|
|
||||||
if (
|
|
||||||
task.config.debug
|
|
||||||
and TaskDebugFlag.keep_node_on_completion in task.config.debug
|
|
||||||
):
|
|
||||||
node.debug_keep_node = True
|
|
||||||
node.save()
|
|
||||||
else:
|
|
||||||
task.mark_failed(
|
|
||||||
Error(
|
|
||||||
code=ErrorCode.TASK_FAILED,
|
|
||||||
errors=[
|
|
||||||
"task failed. exit_status:%s" % event.exit_status,
|
|
||||||
event.stdout[-MAX_OUTPUT_SIZE:],
|
|
||||||
event.stderr[-MAX_OUTPUT_SIZE:],
|
|
||||||
],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if task.config.debug and (
|
|
||||||
TaskDebugFlag.keep_node_on_failure in task.config.debug
|
|
||||||
or TaskDebugFlag.keep_node_on_completion in task.config.debug
|
|
||||||
):
|
|
||||||
node.debug_keep_node = True
|
|
||||||
node.save()
|
|
||||||
|
|
||||||
if not node.debug_keep_node:
|
|
||||||
node_task = NodeTasks.get(machine_id, event.task_id)
|
|
||||||
if node_task:
|
|
||||||
node_task.delete()
|
|
||||||
|
|
||||||
event.stdout = event.stdout[-MAX_OUTPUT_SIZE:]
|
|
||||||
event.stderr = event.stderr[-MAX_OUTPUT_SIZE:]
|
|
||||||
task_event = TaskEvent(
|
|
||||||
task_id=task.task_id, machine_id=machine_id, event_data=WorkerEvent(done=event)
|
|
||||||
)
|
|
||||||
task_event.save()
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def on_worker_event(machine_id: UUID, event: WorkerEvent) -> Result[None]:
|
|
||||||
if event.running:
|
|
||||||
return on_worker_event_running(machine_id, event.running)
|
|
||||||
elif event.done:
|
|
||||||
return on_worker_event_done(machine_id, event.done)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
@ -1,171 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import math
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from onefuzztypes.enums import NodeState, ScalesetState
|
|
||||||
from onefuzztypes.models import AutoScaleConfig, TaskPool
|
|
||||||
|
|
||||||
from .tasks.main import Task
|
|
||||||
from .workers.nodes import Node
|
|
||||||
from .workers.pools import Pool
|
|
||||||
from .workers.scalesets import Scaleset
|
|
||||||
|
|
||||||
|
|
||||||
def scale_up(pool: Pool, scalesets: List[Scaleset], nodes_needed: int) -> None:
|
|
||||||
logging.info("Scaling up")
|
|
||||||
autoscale_config = pool.autoscale
|
|
||||||
if not isinstance(autoscale_config, AutoScaleConfig):
|
|
||||||
return
|
|
||||||
|
|
||||||
for scaleset in scalesets:
|
|
||||||
if scaleset.state in [ScalesetState.running, ScalesetState.resize]:
|
|
||||||
|
|
||||||
max_size = min(scaleset.max_size(), autoscale_config.scaleset_size)
|
|
||||||
logging.info(
|
|
||||||
"scaleset:%s size:%d max_size:%d",
|
|
||||||
scaleset.scaleset_id,
|
|
||||||
scaleset.size,
|
|
||||||
max_size,
|
|
||||||
)
|
|
||||||
if scaleset.size < max_size:
|
|
||||||
current_size = scaleset.size
|
|
||||||
if nodes_needed <= max_size - current_size:
|
|
||||||
scaleset.set_size(current_size + nodes_needed)
|
|
||||||
nodes_needed = 0
|
|
||||||
else:
|
|
||||||
scaleset.set_size(max_size)
|
|
||||||
nodes_needed = nodes_needed - (max_size - current_size)
|
|
||||||
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if nodes_needed == 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
for _ in range(
|
|
||||||
math.ceil(
|
|
||||||
nodes_needed
|
|
||||||
/ min(
|
|
||||||
Scaleset.scaleset_max_size(autoscale_config.image),
|
|
||||||
autoscale_config.scaleset_size,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
):
|
|
||||||
logging.info("Creating Scaleset for Pool %s", pool.name)
|
|
||||||
max_nodes_scaleset = min(
|
|
||||||
Scaleset.scaleset_max_size(autoscale_config.image),
|
|
||||||
autoscale_config.scaleset_size,
|
|
||||||
nodes_needed,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not autoscale_config.region:
|
|
||||||
raise Exception("Region is missing")
|
|
||||||
|
|
||||||
Scaleset.create(
|
|
||||||
pool_name=pool.name,
|
|
||||||
vm_sku=autoscale_config.vm_sku,
|
|
||||||
image=autoscale_config.image,
|
|
||||||
region=autoscale_config.region,
|
|
||||||
size=max_nodes_scaleset,
|
|
||||||
spot_instances=autoscale_config.spot_instances,
|
|
||||||
ephemeral_os_disks=autoscale_config.ephemeral_os_disks,
|
|
||||||
tags={"pool": pool.name},
|
|
||||||
)
|
|
||||||
nodes_needed -= max_nodes_scaleset
|
|
||||||
|
|
||||||
|
|
||||||
def scale_down(scalesets: List[Scaleset], nodes_to_remove: int) -> None:
|
|
||||||
logging.info("Scaling down")
|
|
||||||
for scaleset in scalesets:
|
|
||||||
num_of_nodes = len(Node.search_states(scaleset_id=scaleset.scaleset_id))
|
|
||||||
if scaleset.size != num_of_nodes and scaleset.state not in [
|
|
||||||
ScalesetState.resize,
|
|
||||||
ScalesetState.shutdown,
|
|
||||||
ScalesetState.halt,
|
|
||||||
]:
|
|
||||||
scaleset.set_state(ScalesetState.resize)
|
|
||||||
|
|
||||||
free_nodes = Node.search_states(
|
|
||||||
scaleset_id=scaleset.scaleset_id,
|
|
||||||
states=[NodeState.free],
|
|
||||||
)
|
|
||||||
nodes = []
|
|
||||||
for node in free_nodes:
|
|
||||||
if not node.delete_requested:
|
|
||||||
nodes.append(node)
|
|
||||||
logging.info("Scaleset: %s, #Free Nodes: %s", scaleset.scaleset_id, len(nodes))
|
|
||||||
|
|
||||||
if nodes and nodes_to_remove > 0:
|
|
||||||
max_nodes_remove = min(len(nodes), nodes_to_remove)
|
|
||||||
# All nodes in scaleset are free. Can shutdown VMSS
|
|
||||||
if max_nodes_remove >= scaleset.size and len(nodes) >= scaleset.size:
|
|
||||||
scaleset.set_state(ScalesetState.shutdown)
|
|
||||||
nodes_to_remove = nodes_to_remove - scaleset.size
|
|
||||||
for node in nodes:
|
|
||||||
node.set_shutdown()
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Resize of VMSS needed
|
|
||||||
scaleset.set_size(scaleset.size - max_nodes_remove)
|
|
||||||
nodes_to_remove = nodes_to_remove - max_nodes_remove
|
|
||||||
scaleset.set_state(ScalesetState.resize)
|
|
||||||
|
|
||||||
|
|
||||||
def get_vm_count(tasks: List[Task]) -> int:
|
|
||||||
count = 0
|
|
||||||
for task in tasks:
|
|
||||||
task_pool = task.get_pool()
|
|
||||||
if (
|
|
||||||
not task_pool
|
|
||||||
or not isinstance(task_pool, Pool)
|
|
||||||
or not isinstance(task.config.pool, TaskPool)
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
count += task.config.pool.count
|
|
||||||
return count
|
|
||||||
|
|
||||||
|
|
||||||
def autoscale_pool(pool: Pool) -> None:
|
|
||||||
logging.info("autoscale: %s", pool.autoscale)
|
|
||||||
if not pool.autoscale:
|
|
||||||
return
|
|
||||||
|
|
||||||
# get all the tasks (count not stopped) for the pool
|
|
||||||
tasks = Task.get_tasks_by_pool_name(pool.name)
|
|
||||||
logging.info("Pool: %s, #Tasks %d", pool.name, len(tasks))
|
|
||||||
|
|
||||||
num_of_tasks = get_vm_count(tasks)
|
|
||||||
nodes_needed = max(num_of_tasks, pool.autoscale.min_size)
|
|
||||||
if pool.autoscale.max_size:
|
|
||||||
nodes_needed = min(nodes_needed, pool.autoscale.max_size)
|
|
||||||
|
|
||||||
# do scaleset logic match with pool
|
|
||||||
# get all the scalesets for the pool
|
|
||||||
scalesets = Scaleset.search_by_pool(pool.name)
|
|
||||||
pool_resize = False
|
|
||||||
for scaleset in scalesets:
|
|
||||||
if scaleset.state in ScalesetState.modifying():
|
|
||||||
pool_resize = True
|
|
||||||
break
|
|
||||||
nodes_needed = nodes_needed - scaleset.size
|
|
||||||
|
|
||||||
if pool_resize:
|
|
||||||
return
|
|
||||||
|
|
||||||
logging.info("Pool: %s, #Nodes Needed: %d", pool.name, nodes_needed)
|
|
||||||
if nodes_needed > 0:
|
|
||||||
# resizing scaleset or creating new scaleset.
|
|
||||||
scale_up(pool, scalesets, nodes_needed)
|
|
||||||
elif nodes_needed < 0:
|
|
||||||
for scaleset in scalesets:
|
|
||||||
nodes = Node.search_states(scaleset_id=scaleset.scaleset_id)
|
|
||||||
for node in nodes:
|
|
||||||
if node.delete_requested:
|
|
||||||
nodes_needed += 1
|
|
||||||
if nodes_needed < 0:
|
|
||||||
scale_down(scalesets, abs(nodes_needed))
|
|
@ -1,37 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import os
|
|
||||||
import subprocess # nosec - used for ssh key generation
|
|
||||||
import tempfile
|
|
||||||
from typing import Tuple
|
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
from onefuzztypes.models import Authentication
|
|
||||||
|
|
||||||
|
|
||||||
def generate_keypair() -> Tuple[str, str]:
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
|
||||||
filename = os.path.join(tmpdir, "key")
|
|
||||||
|
|
||||||
cmd = ["ssh-keygen", "-t", "rsa", "-f", filename, "-P", "", "-b", "2048"]
|
|
||||||
subprocess.check_output(cmd) # nosec - all arguments are under our control
|
|
||||||
|
|
||||||
with open(filename, "r") as handle:
|
|
||||||
private = handle.read()
|
|
||||||
|
|
||||||
with open(filename + ".pub", "r") as handle:
|
|
||||||
public = handle.read().strip()
|
|
||||||
|
|
||||||
return (public, private)
|
|
||||||
|
|
||||||
|
|
||||||
def build_auth() -> Authentication:
|
|
||||||
public_key, private_key = generate_keypair()
|
|
||||||
auth = Authentication(
|
|
||||||
password=str(uuid4()), public_key=public_key, private_key=private_key
|
|
||||||
)
|
|
||||||
return auth
|
|
@ -1,328 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import uuid
|
|
||||||
from datetime import timedelta
|
|
||||||
from typing import Any, Dict, Optional, Union
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from azure.core.exceptions import ResourceNotFoundError
|
|
||||||
from azure.mgmt.monitor.models import (
|
|
||||||
AutoscaleProfile,
|
|
||||||
AutoscaleSettingResource,
|
|
||||||
ComparisonOperationType,
|
|
||||||
DiagnosticSettingsResource,
|
|
||||||
LogSettings,
|
|
||||||
MetricStatisticType,
|
|
||||||
MetricTrigger,
|
|
||||||
RetentionPolicy,
|
|
||||||
ScaleAction,
|
|
||||||
ScaleCapacity,
|
|
||||||
ScaleDirection,
|
|
||||||
ScaleRule,
|
|
||||||
ScaleType,
|
|
||||||
TimeAggregationType,
|
|
||||||
)
|
|
||||||
from msrestazure.azure_exceptions import CloudError
|
|
||||||
from onefuzztypes.enums import ErrorCode
|
|
||||||
from onefuzztypes.models import Error
|
|
||||||
from onefuzztypes.primitives import Region
|
|
||||||
|
|
||||||
from .creds import (
|
|
||||||
get_base_region,
|
|
||||||
get_base_resource_group,
|
|
||||||
get_subscription,
|
|
||||||
retry_on_auth_failure,
|
|
||||||
)
|
|
||||||
from .log_analytics import get_workspace_id
|
|
||||||
from .monitor import get_monitor_client
|
|
||||||
|
|
||||||
|
|
||||||
@retry_on_auth_failure()
|
|
||||||
def get_auto_scale_settings(
|
|
||||||
vmss: UUID,
|
|
||||||
) -> Union[Optional[AutoscaleSettingResource], Error]:
|
|
||||||
logging.info("Getting auto scale settings for %s" % vmss)
|
|
||||||
client = get_monitor_client()
|
|
||||||
resource_group = get_base_resource_group()
|
|
||||||
|
|
||||||
try:
|
|
||||||
auto_scale_collections = client.autoscale_settings.list_by_resource_group(
|
|
||||||
resource_group
|
|
||||||
)
|
|
||||||
for auto_scale in auto_scale_collections:
|
|
||||||
if str(auto_scale.target_resource_uri).endswith(str(vmss)):
|
|
||||||
logging.info("Found auto scale settings for %s" % vmss)
|
|
||||||
return auto_scale
|
|
||||||
|
|
||||||
except (ResourceNotFoundError, CloudError):
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.INVALID_CONFIGURATION,
|
|
||||||
errors=[
|
|
||||||
"Failed to check if scaleset %s already has an autoscale resource"
|
|
||||||
% vmss
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
@retry_on_auth_failure()
|
|
||||||
def add_auto_scale_to_vmss(
|
|
||||||
vmss: UUID, auto_scale_profile: AutoscaleProfile
|
|
||||||
) -> Optional[Error]:
|
|
||||||
logging.info("Checking scaleset %s for existing auto scale resources" % vmss)
|
|
||||||
|
|
||||||
existing_auto_scale_resource = get_auto_scale_settings(vmss)
|
|
||||||
|
|
||||||
if isinstance(existing_auto_scale_resource, Error):
|
|
||||||
return existing_auto_scale_resource
|
|
||||||
|
|
||||||
if existing_auto_scale_resource is not None:
|
|
||||||
logging.warning("Scaleset %s already has auto scale resource" % vmss)
|
|
||||||
return None
|
|
||||||
|
|
||||||
auto_scale_resource = create_auto_scale_resource_for(
|
|
||||||
vmss, get_base_region(), auto_scale_profile
|
|
||||||
)
|
|
||||||
if isinstance(auto_scale_resource, Error):
|
|
||||||
return auto_scale_resource
|
|
||||||
|
|
||||||
diagnostics_resource = setup_auto_scale_diagnostics(
|
|
||||||
auto_scale_resource.id, auto_scale_resource.name, get_workspace_id()
|
|
||||||
)
|
|
||||||
if isinstance(diagnostics_resource, Error):
|
|
||||||
return diagnostics_resource
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def update_auto_scale(auto_scale_resource: AutoscaleSettingResource) -> Optional[Error]:
|
|
||||||
logging.info("Updating auto scale resource: %s" % auto_scale_resource.name)
|
|
||||||
client = get_monitor_client()
|
|
||||||
resource_group = get_base_resource_group()
|
|
||||||
|
|
||||||
try:
|
|
||||||
auto_scale_resource = client.autoscale_settings.create_or_update(
|
|
||||||
resource_group, auto_scale_resource.name, auto_scale_resource
|
|
||||||
)
|
|
||||||
logging.info(
|
|
||||||
"Successfully updated auto scale resource: %s" % auto_scale_resource.name
|
|
||||||
)
|
|
||||||
except (ResourceNotFoundError, CloudError):
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.UNABLE_TO_UPDATE,
|
|
||||||
errors=[
|
|
||||||
"unable to update auto scale resource with name: %s and profile: %s"
|
|
||||||
% (auto_scale_resource.name, auto_scale_resource)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def create_auto_scale_resource_for(
|
|
||||||
resource_id: UUID, location: Region, profile: AutoscaleProfile
|
|
||||||
) -> Union[AutoscaleSettingResource, Error]:
|
|
||||||
logging.info("Creating auto scale resource for: %s" % resource_id)
|
|
||||||
client = get_monitor_client()
|
|
||||||
resource_group = get_base_resource_group()
|
|
||||||
subscription = get_subscription()
|
|
||||||
|
|
||||||
scaleset_uri = (
|
|
||||||
"/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/virtualMachineScaleSets/%s" # noqa: E501
|
|
||||||
% (subscription, resource_group, resource_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
params: Dict[str, Any] = {
|
|
||||||
"location": location,
|
|
||||||
"profiles": [profile],
|
|
||||||
"target_resource_uri": scaleset_uri,
|
|
||||||
"enabled": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
auto_scale_resource = client.autoscale_settings.create_or_update(
|
|
||||||
resource_group, str(uuid.uuid4()), params
|
|
||||||
)
|
|
||||||
logging.info(
|
|
||||||
"Successfully created auto scale resource %s for %s"
|
|
||||||
% (auto_scale_resource.id, resource_id)
|
|
||||||
)
|
|
||||||
return auto_scale_resource
|
|
||||||
except (ResourceNotFoundError, CloudError):
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.UNABLE_TO_CREATE,
|
|
||||||
errors=[
|
|
||||||
"unable to create auto scale resource for resource: %s with profile: %s"
|
|
||||||
% (resource_id, profile)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def create_auto_scale_profile(
|
|
||||||
queue_uri: str,
|
|
||||||
min: int,
|
|
||||||
max: int,
|
|
||||||
default: int,
|
|
||||||
scale_out_amount: int,
|
|
||||||
scale_out_cooldown: int,
|
|
||||||
scale_in_amount: int,
|
|
||||||
scale_in_cooldown: int,
|
|
||||||
) -> AutoscaleProfile:
|
|
||||||
return AutoscaleProfile(
|
|
||||||
name=str(uuid.uuid4()),
|
|
||||||
capacity=ScaleCapacity(minimum=min, maximum=max, default=max),
|
|
||||||
# Auto scale tuning guidance:
|
|
||||||
# https://docs.microsoft.com/en-us/azure/architecture/best-practices/auto-scaling
|
|
||||||
rules=[
|
|
||||||
ScaleRule(
|
|
||||||
metric_trigger=MetricTrigger(
|
|
||||||
metric_name="ApproximateMessageCount",
|
|
||||||
metric_resource_uri=queue_uri,
|
|
||||||
# Check every 15 minutes
|
|
||||||
time_grain=timedelta(minutes=15),
|
|
||||||
# The average amount of messages there are in the pool queue
|
|
||||||
time_aggregation=TimeAggregationType.AVERAGE,
|
|
||||||
statistic=MetricStatisticType.COUNT,
|
|
||||||
# Over the past 15 minutes
|
|
||||||
time_window=timedelta(minutes=15),
|
|
||||||
# When there's more than 1 message in the pool queue
|
|
||||||
operator=ComparisonOperationType.GREATER_THAN_OR_EQUAL,
|
|
||||||
threshold=1,
|
|
||||||
divide_per_instance=False,
|
|
||||||
),
|
|
||||||
scale_action=ScaleAction(
|
|
||||||
direction=ScaleDirection.INCREASE,
|
|
||||||
type=ScaleType.CHANGE_COUNT,
|
|
||||||
value=scale_out_amount,
|
|
||||||
cooldown=timedelta(minutes=scale_out_cooldown),
|
|
||||||
),
|
|
||||||
),
|
|
||||||
# Scale in
|
|
||||||
ScaleRule(
|
|
||||||
# Scale in if no work in the past 20 mins
|
|
||||||
metric_trigger=MetricTrigger(
|
|
||||||
metric_name="ApproximateMessageCount",
|
|
||||||
metric_resource_uri=queue_uri,
|
|
||||||
# Check every 10 minutes
|
|
||||||
time_grain=timedelta(minutes=10),
|
|
||||||
# The average amount of messages there are in the pool queue
|
|
||||||
time_aggregation=TimeAggregationType.AVERAGE,
|
|
||||||
statistic=MetricStatisticType.SUM,
|
|
||||||
# Over the past 10 minutes
|
|
||||||
time_window=timedelta(minutes=10),
|
|
||||||
# When there's no messages in the pool queue
|
|
||||||
operator=ComparisonOperationType.EQUALS,
|
|
||||||
threshold=0,
|
|
||||||
divide_per_instance=False,
|
|
||||||
),
|
|
||||||
scale_action=ScaleAction(
|
|
||||||
direction=ScaleDirection.DECREASE,
|
|
||||||
type=ScaleType.CHANGE_COUNT,
|
|
||||||
value=scale_in_amount,
|
|
||||||
cooldown=timedelta(minutes=scale_in_cooldown),
|
|
||||||
),
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_auto_scale_profile(scaleset_id: UUID) -> AutoscaleProfile:
|
|
||||||
logging.info("Getting scaleset %s for existing auto scale resources" % scaleset_id)
|
|
||||||
client = get_monitor_client()
|
|
||||||
resource_group = get_base_resource_group()
|
|
||||||
|
|
||||||
auto_scale_resource = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
auto_scale_collections = client.autoscale_settings.list_by_resource_group(
|
|
||||||
resource_group
|
|
||||||
)
|
|
||||||
for auto_scale in auto_scale_collections:
|
|
||||||
if str(auto_scale.target_resource_uri).endswith(str(scaleset_id)):
|
|
||||||
auto_scale_resource = auto_scale
|
|
||||||
auto_scale_profiles = auto_scale_resource.profiles
|
|
||||||
if len(auto_scale_profiles) != 1:
|
|
||||||
logging.info(
|
|
||||||
"Found more than one autoscaling profile for scaleset %s"
|
|
||||||
% scaleset_id
|
|
||||||
)
|
|
||||||
return auto_scale_profiles[0]
|
|
||||||
|
|
||||||
except (ResourceNotFoundError, CloudError):
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.INVALID_CONFIGURATION,
|
|
||||||
errors=["Failed to query scaleset %s autoscale resource" % scaleset_id],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def default_auto_scale_profile(queue_uri: str, scaleset_size: int) -> AutoscaleProfile:
|
|
||||||
return create_auto_scale_profile(
|
|
||||||
queue_uri, 1, scaleset_size, scaleset_size, 1, 10, 1, 5
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def shutdown_scaleset_rule(queue_uri: str) -> ScaleRule:
|
|
||||||
return ScaleRule(
|
|
||||||
# Scale in if there are 0 or more messages in the queue (aka: every time)
|
|
||||||
metric_trigger=MetricTrigger(
|
|
||||||
metric_name="ApproximateMessageCount",
|
|
||||||
metric_resource_uri=queue_uri,
|
|
||||||
# Check every 10 minutes
|
|
||||||
time_grain=timedelta(minutes=5),
|
|
||||||
# The average amount of messages there are in the pool queue
|
|
||||||
time_aggregation=TimeAggregationType.AVERAGE,
|
|
||||||
statistic=MetricStatisticType.SUM,
|
|
||||||
# Over the past 10 minutes
|
|
||||||
time_window=timedelta(minutes=5),
|
|
||||||
operator=ComparisonOperationType.GREATER_THAN_OR_EQUAL,
|
|
||||||
threshold=0,
|
|
||||||
divide_per_instance=False,
|
|
||||||
),
|
|
||||||
scale_action=ScaleAction(
|
|
||||||
direction=ScaleDirection.DECREASE,
|
|
||||||
type=ScaleType.CHANGE_COUNT,
|
|
||||||
value=1,
|
|
||||||
cooldown=timedelta(minutes=5),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def setup_auto_scale_diagnostics(
|
|
||||||
auto_scale_resource_uri: str,
|
|
||||||
auto_scale_resource_name: str,
|
|
||||||
log_analytics_workspace_id: str,
|
|
||||||
) -> Union[DiagnosticSettingsResource, Error]:
|
|
||||||
logging.info("Setting up diagnostics for auto scale")
|
|
||||||
client = get_monitor_client()
|
|
||||||
|
|
||||||
log_settings = LogSettings(
|
|
||||||
enabled=True,
|
|
||||||
category_group="allLogs",
|
|
||||||
retention_policy=RetentionPolicy(enabled=True, days=30),
|
|
||||||
)
|
|
||||||
|
|
||||||
params: Dict[str, Any] = {
|
|
||||||
"logs": [log_settings],
|
|
||||||
"workspace_id": log_analytics_workspace_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
diagnostics = client.diagnostic_settings.create_or_update(
|
|
||||||
auto_scale_resource_uri, "%s-diagnostics" % auto_scale_resource_name, params
|
|
||||||
)
|
|
||||||
logging.info(
|
|
||||||
"Diagnostics created for auto scale resource: %s" % auto_scale_resource_uri
|
|
||||||
)
|
|
||||||
return diagnostics
|
|
||||||
except (ResourceNotFoundError, CloudError):
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.UNABLE_TO_CREATE,
|
|
||||||
errors=[
|
|
||||||
"unable to setup diagnostics for auto scale resource: %s"
|
|
||||||
% (auto_scale_resource_uri)
|
|
||||||
],
|
|
||||||
)
|
|
@ -1,9 +0,0 @@
|
|||||||
from azure.mgmt.compute import ComputeManagementClient
|
|
||||||
from memoization import cached
|
|
||||||
|
|
||||||
from .creds import get_identity, get_subscription
|
|
||||||
|
|
||||||
|
|
||||||
@cached
|
|
||||||
def get_compute_client() -> ComputeManagementClient:
|
|
||||||
return ComputeManagementClient(get_identity(), get_subscription())
|
|
@ -1,388 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import datetime
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import urllib.parse
|
|
||||||
from typing import Dict, Optional, Tuple, Union, cast
|
|
||||||
|
|
||||||
from azure.common import AzureHttpError, AzureMissingResourceHttpError
|
|
||||||
from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
|
|
||||||
from azure.storage.blob import (
|
|
||||||
BlobClient,
|
|
||||||
BlobSasPermissions,
|
|
||||||
BlobServiceClient,
|
|
||||||
ContainerClient,
|
|
||||||
ContainerSasPermissions,
|
|
||||||
generate_blob_sas,
|
|
||||||
generate_container_sas,
|
|
||||||
)
|
|
||||||
from memoization import cached
|
|
||||||
from onefuzztypes.primitives import Container
|
|
||||||
|
|
||||||
from .storage import (
|
|
||||||
StorageType,
|
|
||||||
choose_account,
|
|
||||||
get_accounts,
|
|
||||||
get_storage_account_name_key,
|
|
||||||
get_storage_account_name_key_by_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
CONTAINER_SAS_DEFAULT_DURATION = datetime.timedelta(days=30)
|
|
||||||
|
|
||||||
|
|
||||||
def get_url(account_name: str) -> str:
|
|
||||||
return f"https://{account_name}.blob.core.windows.net/"
|
|
||||||
|
|
||||||
|
|
||||||
@cached
|
|
||||||
def get_blob_service(account_id: str) -> BlobServiceClient:
|
|
||||||
logging.debug("getting blob container (account_id: %s)", account_id)
|
|
||||||
account_name, account_key = get_storage_account_name_key(account_id)
|
|
||||||
account_url = get_url(account_name)
|
|
||||||
service = BlobServiceClient(account_url=account_url, credential=account_key)
|
|
||||||
return service
|
|
||||||
|
|
||||||
|
|
||||||
def container_metadata(
|
|
||||||
container: Container, account_id: str
|
|
||||||
) -> Optional[Dict[str, str]]:
|
|
||||||
try:
|
|
||||||
result = (
|
|
||||||
get_blob_service(account_id)
|
|
||||||
.get_container_client(container)
|
|
||||||
.get_container_properties()
|
|
||||||
)
|
|
||||||
return cast(Dict[str, str], result)
|
|
||||||
except AzureHttpError:
|
|
||||||
pass
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def find_container(
|
|
||||||
container: Container, storage_type: StorageType
|
|
||||||
) -> Optional[ContainerClient]:
|
|
||||||
accounts = get_accounts(storage_type)
|
|
||||||
|
|
||||||
# check secondary accounts first by searching in reverse.
|
|
||||||
#
|
|
||||||
# By implementation, the primary account is specified first, followed by
|
|
||||||
# any secondary accounts.
|
|
||||||
#
|
|
||||||
# Secondary accounts, if they exist, are preferred for containers and have
|
|
||||||
# increased IOP rates, this should be a slight optimization
|
|
||||||
for account in reversed(accounts):
|
|
||||||
client = get_blob_service(account).get_container_client(container)
|
|
||||||
if client.exists():
|
|
||||||
return client
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def container_exists(container: Container, storage_type: StorageType) -> bool:
|
|
||||||
return find_container(container, storage_type) is not None
|
|
||||||
|
|
||||||
|
|
||||||
def get_containers(storage_type: StorageType) -> Dict[str, Dict[str, str]]:
|
|
||||||
containers: Dict[str, Dict[str, str]] = {}
|
|
||||||
|
|
||||||
for account_id in get_accounts(storage_type):
|
|
||||||
containers.update(
|
|
||||||
{
|
|
||||||
x.name: x.metadata
|
|
||||||
for x in get_blob_service(account_id).list_containers(
|
|
||||||
include_metadata=True
|
|
||||||
)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return containers
|
|
||||||
|
|
||||||
|
|
||||||
def get_container_metadata(
|
|
||||||
container: Container, storage_type: StorageType
|
|
||||||
) -> Optional[Dict[str, str]]:
|
|
||||||
client = find_container(container, storage_type)
|
|
||||||
if client is None:
|
|
||||||
return None
|
|
||||||
result = client.get_container_properties().metadata
|
|
||||||
return cast(Dict[str, str], result)
|
|
||||||
|
|
||||||
|
|
||||||
def add_container_sas_url(
|
|
||||||
container_url: str, duration: datetime.timedelta = CONTAINER_SAS_DEFAULT_DURATION
|
|
||||||
) -> str:
|
|
||||||
parsed = urllib.parse.urlparse(container_url)
|
|
||||||
query = urllib.parse.parse_qs(parsed.query)
|
|
||||||
if "sig" in query:
|
|
||||||
return container_url
|
|
||||||
else:
|
|
||||||
start, expiry = sas_time_window(duration)
|
|
||||||
account_name = parsed.netloc.split(".")[0]
|
|
||||||
account_key = get_storage_account_name_key_by_name(account_name)
|
|
||||||
sas_token = generate_container_sas(
|
|
||||||
account_name=account_name,
|
|
||||||
container_name=parsed.path.split("/")[1],
|
|
||||||
account_key=account_key,
|
|
||||||
permission=ContainerSasPermissions(
|
|
||||||
read=True, write=True, delete=True, list=True
|
|
||||||
),
|
|
||||||
expiry=expiry,
|
|
||||||
start=start,
|
|
||||||
)
|
|
||||||
return f"{container_url}?{sas_token}"
|
|
||||||
|
|
||||||
|
|
||||||
def get_or_create_container_client(
|
|
||||||
container: Container,
|
|
||||||
storage_type: StorageType,
|
|
||||||
metadata: Optional[Dict[str, str]],
|
|
||||||
) -> Optional[ContainerClient]:
|
|
||||||
client = find_container(container, storage_type)
|
|
||||||
if client is None:
|
|
||||||
account = choose_account(storage_type)
|
|
||||||
client = get_blob_service(account).get_container_client(container)
|
|
||||||
try:
|
|
||||||
client.create_container(metadata=metadata)
|
|
||||||
except (ResourceExistsError, AzureHttpError) as err:
|
|
||||||
# note: resource exists error happens during creation if the container
|
|
||||||
# is being deleted
|
|
||||||
logging.error(
|
|
||||||
(
|
|
||||||
"unable to create container. account: %s "
|
|
||||||
"container: %s metadata: %s - %s"
|
|
||||||
),
|
|
||||||
account,
|
|
||||||
container,
|
|
||||||
metadata,
|
|
||||||
err,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
return client
|
|
||||||
|
|
||||||
|
|
||||||
def create_container(
|
|
||||||
container: Container,
|
|
||||||
storage_type: StorageType,
|
|
||||||
metadata: Optional[Dict[str, str]],
|
|
||||||
) -> Optional[str]:
|
|
||||||
client = get_or_create_container_client(container, storage_type, metadata)
|
|
||||||
if client is None:
|
|
||||||
return None
|
|
||||||
return get_container_sas_url_service(
|
|
||||||
client,
|
|
||||||
read=True,
|
|
||||||
write=True,
|
|
||||||
delete=True,
|
|
||||||
list_=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def delete_container(container: Container, storage_type: StorageType) -> bool:
|
|
||||||
accounts = get_accounts(storage_type)
|
|
||||||
deleted = False
|
|
||||||
for account in accounts:
|
|
||||||
service = get_blob_service(account)
|
|
||||||
try:
|
|
||||||
service.delete_container(container)
|
|
||||||
deleted = True
|
|
||||||
except ResourceNotFoundError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
return deleted
|
|
||||||
|
|
||||||
|
|
||||||
def sas_time_window(
|
|
||||||
duration: datetime.timedelta,
|
|
||||||
) -> Tuple[datetime.datetime, datetime.datetime]:
|
|
||||||
# SAS URLs are valid 6 hours earlier, primarily to work around dev
|
|
||||||
# workstations having out-of-sync time. Additionally, SAS URLs are stopped
|
|
||||||
# 15 minutes later than requested based on "Be careful with SAS start time"
|
|
||||||
# guidance.
|
|
||||||
# Ref: https://docs.microsoft.com/en-us/azure/storage/common/storage-sas-overview
|
|
||||||
SAS_START_TIME_DELTA = datetime.timedelta(hours=6)
|
|
||||||
SAS_END_TIME_DELTA = datetime.timedelta(minutes=15)
|
|
||||||
|
|
||||||
now = datetime.datetime.utcnow()
|
|
||||||
start = now - SAS_START_TIME_DELTA
|
|
||||||
expiry = now + duration + SAS_END_TIME_DELTA
|
|
||||||
return (start, expiry)
|
|
||||||
|
|
||||||
|
|
||||||
def get_container_sas_url_service(
|
|
||||||
client: ContainerClient,
|
|
||||||
*,
|
|
||||||
read: bool = False,
|
|
||||||
write: bool = False,
|
|
||||||
delete: bool = False,
|
|
||||||
list_: bool = False,
|
|
||||||
delete_previous_version: bool = False,
|
|
||||||
tag: bool = False,
|
|
||||||
duration: datetime.timedelta = CONTAINER_SAS_DEFAULT_DURATION,
|
|
||||||
) -> str:
|
|
||||||
account_name = client.account_name
|
|
||||||
container_name = client.container_name
|
|
||||||
account_key = get_storage_account_name_key_by_name(account_name)
|
|
||||||
|
|
||||||
start, expiry = sas_time_window(duration)
|
|
||||||
|
|
||||||
sas = generate_container_sas(
|
|
||||||
account_name,
|
|
||||||
container_name,
|
|
||||||
account_key=account_key,
|
|
||||||
permission=ContainerSasPermissions(
|
|
||||||
read=read,
|
|
||||||
write=write,
|
|
||||||
delete=delete,
|
|
||||||
list=list_,
|
|
||||||
delete_previous_version=delete_previous_version,
|
|
||||||
tag=tag,
|
|
||||||
),
|
|
||||||
start=start,
|
|
||||||
expiry=expiry,
|
|
||||||
)
|
|
||||||
|
|
||||||
with_sas = ContainerClient(
|
|
||||||
get_url(account_name),
|
|
||||||
container_name=container_name,
|
|
||||||
credential=sas,
|
|
||||||
)
|
|
||||||
return cast(str, with_sas.url)
|
|
||||||
|
|
||||||
|
|
||||||
def get_container_sas_url(
|
|
||||||
container: Container,
|
|
||||||
storage_type: StorageType,
|
|
||||||
*,
|
|
||||||
read: bool = False,
|
|
||||||
write: bool = False,
|
|
||||||
delete: bool = False,
|
|
||||||
list_: bool = False,
|
|
||||||
) -> str:
|
|
||||||
client = find_container(container, storage_type)
|
|
||||||
if not client:
|
|
||||||
raise Exception("unable to create container sas for missing container")
|
|
||||||
|
|
||||||
return get_container_sas_url_service(
|
|
||||||
client,
|
|
||||||
read=read,
|
|
||||||
write=write,
|
|
||||||
delete=delete,
|
|
||||||
list_=list_,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_file_url(container: Container, name: str, storage_type: StorageType) -> str:
|
|
||||||
client = find_container(container, storage_type)
|
|
||||||
if not client:
|
|
||||||
raise Exception("unable to find container: %s - %s" % (container, storage_type))
|
|
||||||
|
|
||||||
# get_url has a trailing '/'
|
|
||||||
return f"{get_url(client.account_name)}{container}/{name}"
|
|
||||||
|
|
||||||
|
|
||||||
def get_file_sas_url(
|
|
||||||
container: Container,
|
|
||||||
name: str,
|
|
||||||
storage_type: StorageType,
|
|
||||||
*,
|
|
||||||
read: bool = False,
|
|
||||||
add: bool = False,
|
|
||||||
create: bool = False,
|
|
||||||
write: bool = False,
|
|
||||||
delete: bool = False,
|
|
||||||
delete_previous_version: bool = False,
|
|
||||||
tag: bool = False,
|
|
||||||
duration: datetime.timedelta = CONTAINER_SAS_DEFAULT_DURATION,
|
|
||||||
) -> str:
|
|
||||||
client = find_container(container, storage_type)
|
|
||||||
if not client:
|
|
||||||
raise Exception("unable to find container: %s - %s" % (container, storage_type))
|
|
||||||
|
|
||||||
account_key = get_storage_account_name_key_by_name(client.account_name)
|
|
||||||
start, expiry = sas_time_window(duration)
|
|
||||||
|
|
||||||
permission = BlobSasPermissions(
|
|
||||||
read=read,
|
|
||||||
add=add,
|
|
||||||
create=create,
|
|
||||||
write=write,
|
|
||||||
delete=delete,
|
|
||||||
delete_previous_version=delete_previous_version,
|
|
||||||
tag=tag,
|
|
||||||
)
|
|
||||||
sas = generate_blob_sas(
|
|
||||||
client.account_name,
|
|
||||||
container,
|
|
||||||
name,
|
|
||||||
account_key=account_key,
|
|
||||||
permission=permission,
|
|
||||||
expiry=expiry,
|
|
||||||
start=start,
|
|
||||||
)
|
|
||||||
|
|
||||||
with_sas = BlobClient(
|
|
||||||
get_url(client.account_name),
|
|
||||||
container,
|
|
||||||
name,
|
|
||||||
credential=sas,
|
|
||||||
)
|
|
||||||
return cast(str, with_sas.url)
|
|
||||||
|
|
||||||
|
|
||||||
def save_blob(
|
|
||||||
container: Container,
|
|
||||||
name: str,
|
|
||||||
data: Union[str, bytes],
|
|
||||||
storage_type: StorageType,
|
|
||||||
) -> None:
|
|
||||||
client = find_container(container, storage_type)
|
|
||||||
if not client:
|
|
||||||
raise Exception("unable to find container: %s - %s" % (container, storage_type))
|
|
||||||
|
|
||||||
client.get_blob_client(name).upload_blob(data, overwrite=True)
|
|
||||||
|
|
||||||
|
|
||||||
def get_blob(
|
|
||||||
container: Container, name: str, storage_type: StorageType
|
|
||||||
) -> Optional[bytes]:
|
|
||||||
client = find_container(container, storage_type)
|
|
||||||
if not client:
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
return cast(
|
|
||||||
bytes, client.get_blob_client(name).download_blob().content_as_bytes()
|
|
||||||
)
|
|
||||||
except AzureMissingResourceHttpError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def blob_exists(container: Container, name: str, storage_type: StorageType) -> bool:
|
|
||||||
client = find_container(container, storage_type)
|
|
||||||
if not client:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return cast(bool, client.get_blob_client(name).exists())
|
|
||||||
|
|
||||||
|
|
||||||
def delete_blob(container: Container, name: str, storage_type: StorageType) -> bool:
|
|
||||||
client = find_container(container, storage_type)
|
|
||||||
if not client:
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
client.get_blob_client(name).delete_blob()
|
|
||||||
return True
|
|
||||||
except AzureMissingResourceHttpError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def auth_download_url(container: Container, filename: str) -> str:
|
|
||||||
instance = os.environ["ONEFUZZ_INSTANCE"]
|
|
||||||
return "%s/api/download?%s" % (
|
|
||||||
instance,
|
|
||||||
urllib.parse.urlencode({"container": container, "filename": filename}),
|
|
||||||
)
|
|
@ -1,246 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import functools
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import urllib.parse
|
|
||||||
from typing import Any, Callable, Dict, List, Optional, TypeVar, cast
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
import requests
|
|
||||||
from azure.core.exceptions import ClientAuthenticationError
|
|
||||||
from azure.identity import DefaultAzureCredential
|
|
||||||
from azure.keyvault.secrets import SecretClient
|
|
||||||
from azure.mgmt.resource import ResourceManagementClient
|
|
||||||
from azure.mgmt.subscription import SubscriptionClient
|
|
||||||
from memoization import cached
|
|
||||||
from msrestazure.azure_active_directory import MSIAuthentication
|
|
||||||
from msrestazure.tools import parse_resource_id
|
|
||||||
from onefuzztypes.primitives import Container, Region
|
|
||||||
|
|
||||||
from .monkeypatch import allow_more_workers, reduce_logging
|
|
||||||
|
|
||||||
# https://docs.microsoft.com/en-us/graph/api/overview?view=graph-rest-1.0
|
|
||||||
GRAPH_RESOURCE = "https://graph.microsoft.com"
|
|
||||||
GRAPH_RESOURCE_ENDPOINT = "https://graph.microsoft.com/v1.0"
|
|
||||||
|
|
||||||
|
|
||||||
@cached
|
|
||||||
def get_msi() -> MSIAuthentication:
|
|
||||||
allow_more_workers()
|
|
||||||
reduce_logging()
|
|
||||||
return MSIAuthentication()
|
|
||||||
|
|
||||||
|
|
||||||
@cached
|
|
||||||
def get_identity() -> DefaultAzureCredential:
|
|
||||||
allow_more_workers()
|
|
||||||
reduce_logging()
|
|
||||||
return DefaultAzureCredential()
|
|
||||||
|
|
||||||
|
|
||||||
@cached
|
|
||||||
def get_base_resource_group() -> Any: # should be str
|
|
||||||
return parse_resource_id(os.environ["ONEFUZZ_RESOURCE_GROUP"])["resource_group"]
|
|
||||||
|
|
||||||
|
|
||||||
@cached
|
|
||||||
def get_base_region() -> Region:
|
|
||||||
client = ResourceManagementClient(
|
|
||||||
credential=get_identity(), subscription_id=get_subscription()
|
|
||||||
)
|
|
||||||
group = client.resource_groups.get(get_base_resource_group())
|
|
||||||
return Region(group.location)
|
|
||||||
|
|
||||||
|
|
||||||
@cached
|
|
||||||
def get_subscription() -> Any: # should be str
|
|
||||||
return parse_resource_id(os.environ["ONEFUZZ_DATA_STORAGE"])["subscription"]
|
|
||||||
|
|
||||||
|
|
||||||
@cached
|
|
||||||
def get_insights_instrumentation_key() -> Any: # should be str
|
|
||||||
return os.environ["APPINSIGHTS_INSTRUMENTATIONKEY"]
|
|
||||||
|
|
||||||
|
|
||||||
@cached
|
|
||||||
def get_insights_appid() -> str:
|
|
||||||
return os.environ["APPINSIGHTS_APPID"]
|
|
||||||
|
|
||||||
|
|
||||||
@cached
|
|
||||||
def get_instance_name() -> str:
|
|
||||||
return os.environ["ONEFUZZ_INSTANCE_NAME"]
|
|
||||||
|
|
||||||
|
|
||||||
@cached
|
|
||||||
def get_instance_url() -> str:
|
|
||||||
return "https://%s.azurewebsites.net" % get_instance_name()
|
|
||||||
|
|
||||||
|
|
||||||
@cached
|
|
||||||
def get_agent_instance_url() -> str:
|
|
||||||
return get_instance_url()
|
|
||||||
|
|
||||||
|
|
||||||
@cached
|
|
||||||
def get_instance_id() -> UUID:
|
|
||||||
from .containers import get_blob
|
|
||||||
from .storage import StorageType
|
|
||||||
|
|
||||||
blob = get_blob(Container("base-config"), "instance_id", StorageType.config)
|
|
||||||
if blob is None:
|
|
||||||
raise Exception("missing instance_id")
|
|
||||||
return UUID(blob.decode())
|
|
||||||
|
|
||||||
|
|
||||||
DAY_IN_SECONDS = 60 * 60 * 24
|
|
||||||
|
|
||||||
|
|
||||||
@cached(ttl=DAY_IN_SECONDS)
|
|
||||||
def get_regions() -> List[Region]:
|
|
||||||
subscription = get_subscription()
|
|
||||||
client = SubscriptionClient(credential=get_identity())
|
|
||||||
locations = client.subscriptions.list_locations(subscription)
|
|
||||||
return sorted([Region(x.name) for x in locations])
|
|
||||||
|
|
||||||
|
|
||||||
class GraphQueryError(Exception):
|
|
||||||
def __init__(self, message: str, status_code: Optional[int]) -> None:
|
|
||||||
super(GraphQueryError, self).__init__(message)
|
|
||||||
self.message = message
|
|
||||||
self.status_code = status_code
|
|
||||||
|
|
||||||
|
|
||||||
def query_microsoft_graph(
|
|
||||||
method: str,
|
|
||||||
resource: str,
|
|
||||||
params: Optional[Dict] = None,
|
|
||||||
body: Optional[Dict] = None,
|
|
||||||
) -> Dict:
|
|
||||||
cred = get_identity()
|
|
||||||
access_token = cred.get_token(f"{GRAPH_RESOURCE}/.default")
|
|
||||||
|
|
||||||
url = urllib.parse.urljoin(f"{GRAPH_RESOURCE_ENDPOINT}/", resource)
|
|
||||||
headers = {
|
|
||||||
"Authorization": "Bearer %s" % access_token.token,
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
response = requests.request(
|
|
||||||
method=method, url=url, headers=headers, params=params, json=body
|
|
||||||
)
|
|
||||||
|
|
||||||
if 200 <= response.status_code < 300:
|
|
||||||
if response.content and response.content.strip():
|
|
||||||
json = response.json()
|
|
||||||
if isinstance(json, Dict):
|
|
||||||
return json
|
|
||||||
else:
|
|
||||||
raise GraphQueryError(
|
|
||||||
"invalid data expected a json object: HTTP"
|
|
||||||
f" {response.status_code} - {json}",
|
|
||||||
response.status_code,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return {}
|
|
||||||
else:
|
|
||||||
error_text = str(response.content, encoding="utf-8", errors="backslashreplace")
|
|
||||||
raise GraphQueryError(
|
|
||||||
f"request did not succeed: HTTP {response.status_code} - {error_text}",
|
|
||||||
response.status_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def query_microsoft_graph_list(
|
|
||||||
method: str,
|
|
||||||
resource: str,
|
|
||||||
params: Optional[Dict] = None,
|
|
||||||
body: Optional[Dict] = None,
|
|
||||||
) -> List[Any]:
|
|
||||||
result = query_microsoft_graph(
|
|
||||||
method,
|
|
||||||
resource,
|
|
||||||
params,
|
|
||||||
body,
|
|
||||||
)
|
|
||||||
value = result.get("value")
|
|
||||||
if isinstance(value, list):
|
|
||||||
return value
|
|
||||||
else:
|
|
||||||
raise GraphQueryError("Expected data containing a list of values", None)
|
|
||||||
|
|
||||||
|
|
||||||
@cached
|
|
||||||
def get_scaleset_identity_resource_path() -> str:
|
|
||||||
scaleset_id_name = "%s-scalesetid" % get_instance_name()
|
|
||||||
resource_group_path = "/subscriptions/%s/resourceGroups/%s/providers" % (
|
|
||||||
get_subscription(),
|
|
||||||
get_base_resource_group(),
|
|
||||||
)
|
|
||||||
return "%s/Microsoft.ManagedIdentity/userAssignedIdentities/%s" % (
|
|
||||||
resource_group_path,
|
|
||||||
scaleset_id_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@cached
|
|
||||||
def get_scaleset_principal_id() -> UUID:
|
|
||||||
api_version = "2018-11-30" # matches the apiversion in the deployment template
|
|
||||||
client = ResourceManagementClient(
|
|
||||||
credential=get_identity(), subscription_id=get_subscription()
|
|
||||||
)
|
|
||||||
uid = client.resources.get_by_id(get_scaleset_identity_resource_path(), api_version)
|
|
||||||
return UUID(uid.properties["principalId"])
|
|
||||||
|
|
||||||
|
|
||||||
@cached
|
|
||||||
def get_keyvault_client(vault_url: str) -> SecretClient:
|
|
||||||
return SecretClient(vault_url=vault_url, credential=DefaultAzureCredential())
|
|
||||||
|
|
||||||
|
|
||||||
def clear_azure_client_cache() -> None:
|
|
||||||
# clears the memoization of the Azure clients.
|
|
||||||
|
|
||||||
from .compute import get_compute_client
|
|
||||||
from .containers import get_blob_service
|
|
||||||
from .network_mgmt_client import get_network_client
|
|
||||||
from .storage import get_mgmt_client
|
|
||||||
|
|
||||||
# currently memoization.cache does not project the wrapped function's types.
|
|
||||||
# As a workaround, CI comments out the `cached` wrapper, then runs the type
|
|
||||||
# validation. This enables calling the wrapper's clear_cache if it's not
|
|
||||||
# disabled.
|
|
||||||
for func in [
|
|
||||||
get_msi,
|
|
||||||
get_identity,
|
|
||||||
get_compute_client,
|
|
||||||
get_blob_service,
|
|
||||||
get_network_client,
|
|
||||||
get_mgmt_client,
|
|
||||||
]:
|
|
||||||
clear_func = getattr(func, "clear_cache", None)
|
|
||||||
if clear_func is not None:
|
|
||||||
clear_func()
|
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T", bound=Callable[..., Any])
|
|
||||||
|
|
||||||
|
|
||||||
class retry_on_auth_failure:
|
|
||||||
def __call__(self, func: T) -> T:
|
|
||||||
@functools.wraps(func)
|
|
||||||
def decorated(*args, **kwargs): # type: ignore
|
|
||||||
try:
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
except ClientAuthenticationError as err:
|
|
||||||
logging.warning(
|
|
||||||
"clearing authentication cache after auth failure: %s", err
|
|
||||||
)
|
|
||||||
|
|
||||||
clear_azure_client_cache()
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
|
|
||||||
return cast(T, decorated)
|
|
@ -1,29 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from azure.core.exceptions import ResourceNotFoundError
|
|
||||||
from msrestazure.azure_exceptions import CloudError
|
|
||||||
|
|
||||||
from .compute import get_compute_client
|
|
||||||
|
|
||||||
|
|
||||||
def list_disks(resource_group: str) -> Any:
|
|
||||||
logging.info("listing disks %s", resource_group)
|
|
||||||
compute_client = get_compute_client()
|
|
||||||
return compute_client.disks.list_by_resource_group(resource_group)
|
|
||||||
|
|
||||||
|
|
||||||
def delete_disk(resource_group: str, name: str) -> bool:
|
|
||||||
logging.info("deleting disks %s : %s", resource_group, name)
|
|
||||||
compute_client = get_compute_client()
|
|
||||||
try:
|
|
||||||
compute_client.disks.begin_delete(resource_group, name)
|
|
||||||
return True
|
|
||||||
except (ResourceNotFoundError, CloudError) as err:
|
|
||||||
logging.error("unable to delete disk: %s", err)
|
|
||||||
return False
|
|
@ -1,51 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
from typing import Dict, List, Protocol
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from ..config import InstanceConfig
|
|
||||||
from .creds import query_microsoft_graph_list
|
|
||||||
|
|
||||||
|
|
||||||
class GroupMembershipChecker(Protocol):
|
|
||||||
def is_member(self, group_ids: List[UUID], member_id: UUID) -> bool:
|
|
||||||
"""Check if member is part of at least one of the groups"""
|
|
||||||
if member_id in group_ids:
|
|
||||||
return True
|
|
||||||
|
|
||||||
groups = self.get_groups(member_id)
|
|
||||||
for g in group_ids:
|
|
||||||
if g in groups:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get_groups(self, member_id: UUID) -> List[UUID]:
|
|
||||||
"""Gets all the groups of the provided member"""
|
|
||||||
|
|
||||||
|
|
||||||
def create_group_membership_checker() -> GroupMembershipChecker:
|
|
||||||
config = InstanceConfig.fetch()
|
|
||||||
if config.group_membership:
|
|
||||||
return StaticGroupMembership(config.group_membership)
|
|
||||||
else:
|
|
||||||
return AzureADGroupMembership()
|
|
||||||
|
|
||||||
|
|
||||||
class AzureADGroupMembership(GroupMembershipChecker):
|
|
||||||
def get_groups(self, member_id: UUID) -> List[UUID]:
|
|
||||||
response = query_microsoft_graph_list(
|
|
||||||
method="GET", resource=f"users/{member_id}/transitiveMemberOf"
|
|
||||||
)
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
class StaticGroupMembership(GroupMembershipChecker):
|
|
||||||
def __init__(self, memberships: Dict[str, List[UUID]]):
|
|
||||||
self.memberships = memberships
|
|
||||||
|
|
||||||
def get_groups(self, member_id: UUID) -> List[UUID]:
|
|
||||||
return self.memberships.get(str(member_id), [])
|
|
@ -1,55 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
from azure.core.exceptions import ResourceNotFoundError
|
|
||||||
from memoization import cached
|
|
||||||
from msrestazure.azure_exceptions import CloudError
|
|
||||||
from msrestazure.tools import parse_resource_id
|
|
||||||
from onefuzztypes.enums import OS, ErrorCode
|
|
||||||
from onefuzztypes.models import Error
|
|
||||||
from onefuzztypes.primitives import Region
|
|
||||||
|
|
||||||
from .compute import get_compute_client
|
|
||||||
|
|
||||||
|
|
||||||
@cached(ttl=60)
|
|
||||||
def get_os(region: Region, image: str) -> Union[Error, OS]:
|
|
||||||
client = get_compute_client()
|
|
||||||
# The dict returned here may not have any defined keys.
|
|
||||||
#
|
|
||||||
# See: https://github.com/Azure/msrestazure-for-python/blob/v0.6.3/msrestazure/tools.py#L134 # noqa: E501
|
|
||||||
parsed = parse_resource_id(image)
|
|
||||||
if "resource_group" in parsed:
|
|
||||||
if parsed["type"] == "galleries":
|
|
||||||
try:
|
|
||||||
# See: https://docs.microsoft.com/en-us/rest/api/compute/gallery-images/get#galleryimage # noqa: E501
|
|
||||||
name = client.gallery_images.get(
|
|
||||||
parsed["resource_group"], parsed["name"], parsed["child_name_1"]
|
|
||||||
).os_type.lower()
|
|
||||||
except (ResourceNotFoundError, CloudError) as err:
|
|
||||||
return Error(code=ErrorCode.INVALID_IMAGE, errors=[str(err)])
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
# See: https://docs.microsoft.com/en-us/rest/api/compute/images/get
|
|
||||||
name = client.images.get(
|
|
||||||
parsed["resource_group"], parsed["name"]
|
|
||||||
).storage_profile.os_disk.os_type.lower()
|
|
||||||
except (ResourceNotFoundError, CloudError) as err:
|
|
||||||
return Error(code=ErrorCode.INVALID_IMAGE, errors=[str(err)])
|
|
||||||
else:
|
|
||||||
publisher, offer, sku, version = image.split(":")
|
|
||||||
try:
|
|
||||||
if version == "latest":
|
|
||||||
version = client.virtual_machine_images.list(
|
|
||||||
region, publisher, offer, sku, top=1
|
|
||||||
)[0].name
|
|
||||||
name = client.virtual_machine_images.get(
|
|
||||||
region, publisher, offer, sku, version
|
|
||||||
).os_disk_image.operating_system.lower()
|
|
||||||
except (ResourceNotFoundError, CloudError) as err:
|
|
||||||
return Error(code=ErrorCode.INVALID_IMAGE, errors=[str(err)])
|
|
||||||
return OS[name]
|
|
@ -1,168 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from typing import Any, Dict, Optional, Union
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from azure.core.exceptions import ResourceNotFoundError
|
|
||||||
from azure.mgmt.network.models import Subnet
|
|
||||||
from msrestazure.azure_exceptions import CloudError
|
|
||||||
from msrestazure.tools import parse_resource_id
|
|
||||||
from onefuzztypes.enums import ErrorCode
|
|
||||||
from onefuzztypes.models import Error
|
|
||||||
from onefuzztypes.primitives import Region
|
|
||||||
|
|
||||||
from .creds import get_base_resource_group
|
|
||||||
from .network import Network
|
|
||||||
from .network_mgmt_client import get_network_client
|
|
||||||
from .nsg import NSG
|
|
||||||
from .vmss import get_instance_id
|
|
||||||
|
|
||||||
|
|
||||||
def get_scaleset_instance_ip(scaleset: UUID, machine_id: UUID) -> Optional[str]:
|
|
||||||
instance = get_instance_id(scaleset, machine_id)
|
|
||||||
if not isinstance(instance, str):
|
|
||||||
return None
|
|
||||||
|
|
||||||
resource_group = get_base_resource_group()
|
|
||||||
|
|
||||||
client = get_network_client()
|
|
||||||
intf = client.network_interfaces.list_virtual_machine_scale_set_network_interfaces(
|
|
||||||
resource_group, str(scaleset)
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
for interface in intf:
|
|
||||||
resource = parse_resource_id(interface.virtual_machine.id)
|
|
||||||
if resource.get("resource_name") != instance:
|
|
||||||
continue
|
|
||||||
|
|
||||||
for config in interface.ip_configurations:
|
|
||||||
if config.private_ip_address is None:
|
|
||||||
continue
|
|
||||||
return str(config.private_ip_address)
|
|
||||||
except (ResourceNotFoundError, CloudError):
|
|
||||||
# this can fail if an interface is removed during the iteration
|
|
||||||
pass
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def get_ip(resource_group: str, name: str) -> Optional[Any]:
|
|
||||||
logging.info("getting ip %s:%s", resource_group, name)
|
|
||||||
network_client = get_network_client()
|
|
||||||
try:
|
|
||||||
return network_client.public_ip_addresses.get(resource_group, name)
|
|
||||||
except (ResourceNotFoundError, CloudError):
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def delete_ip(resource_group: str, name: str) -> Any:
|
|
||||||
logging.info("deleting ip %s:%s", resource_group, name)
|
|
||||||
network_client = get_network_client()
|
|
||||||
return network_client.public_ip_addresses.begin_delete(resource_group, name)
|
|
||||||
|
|
||||||
|
|
||||||
def create_ip(resource_group: str, name: str, region: Region) -> Any:
|
|
||||||
logging.info("creating ip for %s:%s in %s", resource_group, name, region)
|
|
||||||
|
|
||||||
network_client = get_network_client()
|
|
||||||
params: Dict[str, Union[str, Dict[str, str]]] = {
|
|
||||||
"location": region,
|
|
||||||
"public_ip_allocation_method": "Dynamic",
|
|
||||||
}
|
|
||||||
if "ONEFUZZ_OWNER" in os.environ:
|
|
||||||
params["tags"] = {"OWNER": os.environ["ONEFUZZ_OWNER"]}
|
|
||||||
return network_client.public_ip_addresses.begin_create_or_update(
|
|
||||||
resource_group, name, params
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_public_nic(resource_group: str, name: str) -> Optional[Any]:
|
|
||||||
logging.info("getting nic: %s %s", resource_group, name)
|
|
||||||
network_client = get_network_client()
|
|
||||||
try:
|
|
||||||
return network_client.network_interfaces.get(resource_group, name)
|
|
||||||
except (ResourceNotFoundError, CloudError):
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def delete_nic(resource_group: str, name: str) -> Optional[Any]:
|
|
||||||
logging.info("deleting nic %s:%s", resource_group, name)
|
|
||||||
network_client = get_network_client()
|
|
||||||
return network_client.network_interfaces.begin_delete(resource_group, name)
|
|
||||||
|
|
||||||
|
|
||||||
def create_public_nic(
|
|
||||||
resource_group: str, name: str, region: Region, nsg: Optional[NSG]
|
|
||||||
) -> Optional[Error]:
|
|
||||||
logging.info("creating nic for %s:%s in %s", resource_group, name, region)
|
|
||||||
|
|
||||||
network = Network(region)
|
|
||||||
subnet_id = network.get_id()
|
|
||||||
if subnet_id is None:
|
|
||||||
network.create()
|
|
||||||
return None
|
|
||||||
|
|
||||||
if nsg:
|
|
||||||
subnet = network.get_subnet()
|
|
||||||
if isinstance(subnet, Subnet) and not subnet.network_security_group:
|
|
||||||
result = nsg.associate_subnet(network.get_vnet(), subnet)
|
|
||||||
if isinstance(result, Error):
|
|
||||||
return result
|
|
||||||
return None
|
|
||||||
|
|
||||||
ip = get_ip(resource_group, name)
|
|
||||||
if not ip:
|
|
||||||
create_ip(resource_group, name, region)
|
|
||||||
return None
|
|
||||||
|
|
||||||
params = {
|
|
||||||
"location": region,
|
|
||||||
"ip_configurations": [
|
|
||||||
{
|
|
||||||
"name": "myIPConfig",
|
|
||||||
"public_ip_address": ip,
|
|
||||||
"subnet": {"id": subnet_id},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
if "ONEFUZZ_OWNER" in os.environ:
|
|
||||||
params["tags"] = {"OWNER": os.environ["ONEFUZZ_OWNER"]}
|
|
||||||
|
|
||||||
network_client = get_network_client()
|
|
||||||
try:
|
|
||||||
network_client.network_interfaces.begin_create_or_update(
|
|
||||||
resource_group, name, params
|
|
||||||
)
|
|
||||||
except (ResourceNotFoundError, CloudError) as err:
|
|
||||||
if "RetryableError" not in repr(err):
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.VM_CREATE_FAILED,
|
|
||||||
errors=["unable to create nic: %s" % err],
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def get_public_ip(resource_id: str) -> Optional[str]:
|
|
||||||
logging.info("getting ip for %s", resource_id)
|
|
||||||
network_client = get_network_client()
|
|
||||||
resource = parse_resource_id(resource_id)
|
|
||||||
ip = (
|
|
||||||
network_client.network_interfaces.get(
|
|
||||||
resource["resource_group"], resource["name"]
|
|
||||||
)
|
|
||||||
.ip_configurations[0]
|
|
||||||
.public_ip_address
|
|
||||||
)
|
|
||||||
resource = parse_resource_id(ip.id)
|
|
||||||
ip = network_client.public_ip_addresses.get(
|
|
||||||
resource["resource_group"], resource["name"]
|
|
||||||
).ip_address
|
|
||||||
if ip is None:
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
return str(ip)
|
|
@ -1,43 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import os
|
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
from azure.mgmt.loganalytics import LogAnalyticsManagementClient
|
|
||||||
from memoization import cached
|
|
||||||
|
|
||||||
from .creds import get_base_resource_group, get_identity, get_subscription
|
|
||||||
|
|
||||||
|
|
||||||
@cached
|
|
||||||
def get_monitor_client() -> LogAnalyticsManagementClient:
|
|
||||||
return LogAnalyticsManagementClient(get_identity(), get_subscription())
|
|
||||||
|
|
||||||
|
|
||||||
@cached(ttl=60)
|
|
||||||
def get_monitor_settings() -> Dict[str, str]:
|
|
||||||
resource_group = get_base_resource_group()
|
|
||||||
workspace_name = os.environ["ONEFUZZ_MONITOR"]
|
|
||||||
client = get_monitor_client()
|
|
||||||
customer_id = client.workspaces.get(resource_group, workspace_name).customer_id
|
|
||||||
shared_key = client.shared_keys.get_shared_keys(
|
|
||||||
resource_group, workspace_name
|
|
||||||
).primary_shared_key
|
|
||||||
return {"id": customer_id, "key": shared_key}
|
|
||||||
|
|
||||||
|
|
||||||
def get_workspace_id() -> str:
|
|
||||||
# TODO:
|
|
||||||
# Once #1679 merges, we can use ONEFUZZ_MONITOR instead of ONEFUZZ_INSTANCE_NAME
|
|
||||||
workspace_id = (
|
|
||||||
"/subscriptions/%s/resourceGroups/%s/providers/microsoft.operationalinsights/workspaces/%s" # noqa: E501
|
|
||||||
% (
|
|
||||||
get_subscription(),
|
|
||||||
get_base_resource_group(),
|
|
||||||
os.environ["ONEFUZZ_INSTANCE_NAME"],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return workspace_id
|
|
@ -1,9 +0,0 @@
|
|||||||
from azure.mgmt.monitor import MonitorManagementClient
|
|
||||||
from memoization import cached
|
|
||||||
|
|
||||||
from .creds import get_identity, get_subscription
|
|
||||||
|
|
||||||
|
|
||||||
@cached
|
|
||||||
def get_monitor_client() -> MonitorManagementClient:
|
|
||||||
return MonitorManagementClient(get_identity(), get_subscription())
|
|
@ -1,56 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import inspect
|
|
||||||
import logging
|
|
||||||
|
|
||||||
WORKERS_DONE = False
|
|
||||||
REDUCE_LOGGING = False
|
|
||||||
|
|
||||||
|
|
||||||
def allow_more_workers() -> None:
|
|
||||||
global WORKERS_DONE
|
|
||||||
if WORKERS_DONE:
|
|
||||||
return
|
|
||||||
|
|
||||||
stack = inspect.stack()
|
|
||||||
for entry in stack:
|
|
||||||
if entry.filename.endswith("azure_functions_worker/dispatcher.py"):
|
|
||||||
if entry.frame.f_locals["self"]._sync_call_tp._max_workers == 1:
|
|
||||||
logging.info("bumped thread worker count to 50")
|
|
||||||
entry.frame.f_locals["self"]._sync_call_tp._max_workers = 50
|
|
||||||
|
|
||||||
WORKERS_DONE = True
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: Replace this with a better method for filtering out logging
|
|
||||||
# See https://github.com/Azure/azure-functions-python-worker/issues/743
|
|
||||||
def reduce_logging() -> None:
|
|
||||||
global REDUCE_LOGGING
|
|
||||||
if REDUCE_LOGGING:
|
|
||||||
return
|
|
||||||
|
|
||||||
to_quiet = [
|
|
||||||
"azure",
|
|
||||||
"cli",
|
|
||||||
"grpc",
|
|
||||||
"concurrent",
|
|
||||||
"oauthlib",
|
|
||||||
"msrest",
|
|
||||||
"opencensus",
|
|
||||||
"urllib3",
|
|
||||||
"requests",
|
|
||||||
"aiohttp",
|
|
||||||
"asyncio",
|
|
||||||
"adal-python",
|
|
||||||
]
|
|
||||||
|
|
||||||
for name in logging.Logger.manager.loggerDict:
|
|
||||||
logger = logging.getLogger(name)
|
|
||||||
for prefix in to_quiet:
|
|
||||||
if logger.name.startswith(prefix):
|
|
||||||
logger.level = logging.WARN
|
|
||||||
|
|
||||||
REDUCE_LOGGING = True
|
|
@ -1,78 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import uuid
|
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
from azure.mgmt.network.models import Subnet, VirtualNetwork
|
|
||||||
from msrestazure.azure_exceptions import CloudError
|
|
||||||
from onefuzztypes.enums import ErrorCode
|
|
||||||
from onefuzztypes.models import Error, NetworkConfig
|
|
||||||
from onefuzztypes.primitives import Region
|
|
||||||
|
|
||||||
from ..config import InstanceConfig
|
|
||||||
from .creds import get_base_resource_group
|
|
||||||
from .subnet import create_virtual_network, get_subnet, get_subnet_id, get_vnet
|
|
||||||
|
|
||||||
# This was generated randomly and should be preserved moving forwards
|
|
||||||
NETWORK_GUID_NAMESPACE = uuid.UUID("372977ad-b533-416a-b1b4-f770898e0b11")
|
|
||||||
|
|
||||||
|
|
||||||
class Network:
|
|
||||||
def __init__(self, region: Region):
|
|
||||||
self.group = get_base_resource_group()
|
|
||||||
self.region = region
|
|
||||||
self.network_config = InstanceConfig.fetch().network_config
|
|
||||||
|
|
||||||
# Network names will be calculated from the address_space/subnet
|
|
||||||
# *except* if they are the original values. This allows backwards
|
|
||||||
# compatibility to existing configs if you don't change the network
|
|
||||||
# configs.
|
|
||||||
if (
|
|
||||||
self.network_config.address_space
|
|
||||||
== NetworkConfig.__fields__["address_space"].default
|
|
||||||
and self.network_config.subnet == NetworkConfig.__fields__["subnet"].default
|
|
||||||
):
|
|
||||||
self.name: str = self.region
|
|
||||||
else:
|
|
||||||
network_id = uuid.uuid5(
|
|
||||||
NETWORK_GUID_NAMESPACE,
|
|
||||||
"|".join(
|
|
||||||
[self.network_config.address_space, self.network_config.subnet]
|
|
||||||
),
|
|
||||||
)
|
|
||||||
self.name = f"{self.region}-{network_id}"
|
|
||||||
|
|
||||||
def exists(self) -> bool:
|
|
||||||
return self.get_id() is not None
|
|
||||||
|
|
||||||
def get_id(self) -> Optional[str]:
|
|
||||||
return get_subnet_id(self.group, self.name, self.name)
|
|
||||||
|
|
||||||
def get_subnet(self) -> Optional[Subnet]:
|
|
||||||
return get_subnet(self.group, self.name, self.name)
|
|
||||||
|
|
||||||
def get_vnet(self) -> Optional[VirtualNetwork]:
|
|
||||||
return get_vnet(self.group, self.name)
|
|
||||||
|
|
||||||
def create(self) -> Union[None, Error]:
|
|
||||||
if not self.exists():
|
|
||||||
result = create_virtual_network(
|
|
||||||
self.group, self.name, self.region, self.network_config
|
|
||||||
)
|
|
||||||
if isinstance(result, CloudError):
|
|
||||||
error = Error(
|
|
||||||
code=ErrorCode.UNABLE_TO_CREATE_NETWORK, errors=[result.message]
|
|
||||||
)
|
|
||||||
logging.error(
|
|
||||||
"network creation failed: %s:%s- %s",
|
|
||||||
self.name,
|
|
||||||
self.region,
|
|
||||||
error,
|
|
||||||
)
|
|
||||||
return error
|
|
||||||
|
|
||||||
return None
|
|
@ -1,9 +0,0 @@
|
|||||||
from azure.mgmt.network import NetworkManagementClient
|
|
||||||
from memoization import cached
|
|
||||||
|
|
||||||
from .creds import get_identity, get_subscription
|
|
||||||
|
|
||||||
|
|
||||||
@cached
|
|
||||||
def get_network_client() -> NetworkManagementClient:
|
|
||||||
return NetworkManagementClient(get_identity(), get_subscription())
|
|
@ -1,489 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from typing import Dict, List, Optional, Set, Union, cast
|
|
||||||
|
|
||||||
from azure.core.exceptions import HttpResponseError, ResourceNotFoundError
|
|
||||||
from azure.mgmt.network.models import (
|
|
||||||
NetworkInterface,
|
|
||||||
NetworkSecurityGroup,
|
|
||||||
SecurityRule,
|
|
||||||
SecurityRuleAccess,
|
|
||||||
Subnet,
|
|
||||||
VirtualNetwork,
|
|
||||||
)
|
|
||||||
from msrestazure.azure_exceptions import CloudError
|
|
||||||
from onefuzztypes.enums import ErrorCode
|
|
||||||
from onefuzztypes.models import Error, NetworkSecurityGroupConfig
|
|
||||||
from onefuzztypes.primitives import Region
|
|
||||||
from pydantic import BaseModel, validator
|
|
||||||
|
|
||||||
from .creds import get_base_resource_group
|
|
||||||
from .network_mgmt_client import get_network_client
|
|
||||||
|
|
||||||
|
|
||||||
def is_concurrent_request_error(err: str) -> bool:
|
|
||||||
return "The request failed due to conflict with a concurrent request" in str(err)
|
|
||||||
|
|
||||||
|
|
||||||
def get_nsg(name: str) -> Optional[NetworkSecurityGroup]:
|
|
||||||
resource_group = get_base_resource_group()
|
|
||||||
|
|
||||||
logging.debug("getting nsg: %s", name)
|
|
||||||
network_client = get_network_client()
|
|
||||||
try:
|
|
||||||
nsg = network_client.network_security_groups.get(resource_group, name)
|
|
||||||
return cast(NetworkSecurityGroup, nsg)
|
|
||||||
except (ResourceNotFoundError, CloudError) as err:
|
|
||||||
logging.debug("nsg %s does not exist: %s", name, err)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def create_nsg(name: str, location: Region) -> Union[None, Error]:
|
|
||||||
resource_group = get_base_resource_group()
|
|
||||||
|
|
||||||
logging.info("creating nsg %s:%s:%s", resource_group, location, name)
|
|
||||||
network_client = get_network_client()
|
|
||||||
|
|
||||||
params: Dict = {
|
|
||||||
"location": location,
|
|
||||||
}
|
|
||||||
|
|
||||||
if "ONEFUZZ_OWNER" in os.environ:
|
|
||||||
params["tags"] = {"OWNER": os.environ["ONEFUZZ_OWNER"]}
|
|
||||||
|
|
||||||
try:
|
|
||||||
network_client.network_security_groups.begin_create_or_update(
|
|
||||||
resource_group, name, params
|
|
||||||
)
|
|
||||||
except (ResourceNotFoundError, CloudError) as err:
|
|
||||||
if is_concurrent_request_error(str(err)):
|
|
||||||
logging.debug(
|
|
||||||
"create NSG had conflicts with concurrent request, ignoring %s", err
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.UNABLE_TO_CREATE,
|
|
||||||
errors=["Unable to create nsg %s due to %s" % (name, err)],
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def list_nsgs() -> List[NetworkSecurityGroup]:
|
|
||||||
resource_group = get_base_resource_group()
|
|
||||||
network_client = get_network_client()
|
|
||||||
return list(network_client.network_security_groups.list(resource_group))
|
|
||||||
|
|
||||||
|
|
||||||
def update_nsg(nsg: NetworkSecurityGroup) -> Union[None, Error]:
|
|
||||||
resource_group = get_base_resource_group()
|
|
||||||
|
|
||||||
logging.info("updating nsg %s:%s:%s", resource_group, nsg.location, nsg.name)
|
|
||||||
network_client = get_network_client()
|
|
||||||
|
|
||||||
try:
|
|
||||||
network_client.network_security_groups.begin_create_or_update(
|
|
||||||
resource_group, nsg.name, nsg
|
|
||||||
)
|
|
||||||
except (ResourceNotFoundError, CloudError) as err:
|
|
||||||
if is_concurrent_request_error(str(err)):
|
|
||||||
logging.debug(
|
|
||||||
"create NSG had conflicts with concurrent request, ignoring %s", err
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.UNABLE_TO_CREATE,
|
|
||||||
errors=["Unable to update nsg %s due to %s" % (nsg.name, err)],
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
# Return True if NSG is created using OneFuzz naming convention.
|
|
||||||
# Therefore NSG belongs to OneFuzz.
|
|
||||||
def ok_to_delete(active_regions: Set[Region], nsg_region: str, nsg_name: str) -> bool:
|
|
||||||
return nsg_region not in active_regions and nsg_region == nsg_name
|
|
||||||
|
|
||||||
|
|
||||||
def is_onefuzz_nsg(nsg_region: str, nsg_name: str) -> bool:
|
|
||||||
return nsg_region == nsg_name
|
|
||||||
|
|
||||||
|
|
||||||
# Returns True if deletion completed (thus resource not found) or successfully started.
|
|
||||||
# Returns False if failed to start deletion.
|
|
||||||
def start_delete_nsg(name: str) -> bool:
|
|
||||||
# NSG can be only deleted if no other resource is associated with it
|
|
||||||
resource_group = get_base_resource_group()
|
|
||||||
|
|
||||||
logging.info("deleting nsg: %s %s", resource_group, name)
|
|
||||||
network_client = get_network_client()
|
|
||||||
|
|
||||||
try:
|
|
||||||
network_client.network_security_groups.begin_delete(resource_group, name)
|
|
||||||
return True
|
|
||||||
except HttpResponseError as err:
|
|
||||||
err_str = str(err)
|
|
||||||
if (
|
|
||||||
"cannot be deleted because it is in use by the following resources"
|
|
||||||
) in err_str:
|
|
||||||
return False
|
|
||||||
except ResourceNotFoundError:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def set_allowed(name: str, sources: NetworkSecurityGroupConfig) -> Union[None, Error]:
|
|
||||||
resource_group = get_base_resource_group()
|
|
||||||
nsg = get_nsg(name)
|
|
||||||
if not nsg:
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.UNABLE_TO_FIND,
|
|
||||||
errors=["cannot update nsg rules. nsg %s not found" % name],
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.info(
|
|
||||||
"setting allowed incoming connection sources for nsg: %s %s",
|
|
||||||
resource_group,
|
|
||||||
name,
|
|
||||||
)
|
|
||||||
all_sources = sources.allowed_ips + sources.allowed_service_tags
|
|
||||||
security_rules = []
|
|
||||||
# NSG security rule priority range defined here:
|
|
||||||
# https://docs.microsoft.com/en-us/azure/virtual-network/network-security-groups-overview
|
|
||||||
min_priority = 100
|
|
||||||
# NSG rules per NSG limits:
|
|
||||||
# https://docs.microsoft.com/en-us/azure/azure-resource-manager/management/azure-subscription-service-limits?toc=/azure/virtual-network/toc.json#networking-limits
|
|
||||||
max_rule_count = 1000
|
|
||||||
if len(all_sources) > max_rule_count:
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.INVALID_REQUEST,
|
|
||||||
errors=[
|
|
||||||
"too many rules provided %d. Max allowed: %d"
|
|
||||||
% ((len(all_sources)), max_rule_count),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
priority = min_priority
|
|
||||||
for src in all_sources:
|
|
||||||
security_rules.append(
|
|
||||||
SecurityRule(
|
|
||||||
name="Allow" + str(priority),
|
|
||||||
protocol="*",
|
|
||||||
source_port_range="*",
|
|
||||||
destination_port_range="*",
|
|
||||||
source_address_prefix=src,
|
|
||||||
destination_address_prefix="*",
|
|
||||||
access=SecurityRuleAccess.ALLOW,
|
|
||||||
priority=priority, # between 100 and 4096
|
|
||||||
direction="Inbound",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
# Will not exceed `max_rule_count` or max NSG priority (4096)
|
|
||||||
# due to earlier check of `len(all_sources)`.
|
|
||||||
priority += 1
|
|
||||||
|
|
||||||
nsg.security_rules = security_rules
|
|
||||||
return update_nsg(nsg)
|
|
||||||
|
|
||||||
|
|
||||||
def clear_all_rules(name: str) -> Union[None, Error]:
|
|
||||||
return set_allowed(name, NetworkSecurityGroupConfig())
|
|
||||||
|
|
||||||
|
|
||||||
def get_all_rules(name: str) -> Union[Error, List[SecurityRule]]:
|
|
||||||
nsg = get_nsg(name)
|
|
||||||
if not nsg:
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.UNABLE_TO_FIND,
|
|
||||||
errors=["cannot get nsg rules. nsg %s not found" % name],
|
|
||||||
)
|
|
||||||
|
|
||||||
return cast(List[SecurityRule], nsg.security_rules)
|
|
||||||
|
|
||||||
|
|
||||||
def associate_nic(name: str, nic: NetworkInterface) -> Union[None, Error]:
|
|
||||||
resource_group = get_base_resource_group()
|
|
||||||
nsg = get_nsg(name)
|
|
||||||
if not nsg:
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.UNABLE_TO_FIND,
|
|
||||||
errors=["cannot associate nic. nsg %s not found" % name],
|
|
||||||
)
|
|
||||||
|
|
||||||
if nsg.location != nic.location:
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.UNABLE_TO_UPDATE,
|
|
||||||
errors=[
|
|
||||||
"network interface and nsg have to be in the same region.",
|
|
||||||
"nsg %s %s, nic: %s %s"
|
|
||||||
% (nsg.name, nsg.location, nic.name, nic.location),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
if nic.network_security_group and nic.network_security_group.id == nsg.id:
|
|
||||||
logging.info(
|
|
||||||
"NIC %s and NSG %s already associated, not updating", nic.name, name
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
logging.info("associating nic %s with nsg: %s %s", nic.name, resource_group, name)
|
|
||||||
|
|
||||||
nic.network_security_group = nsg
|
|
||||||
network_client = get_network_client()
|
|
||||||
try:
|
|
||||||
network_client.network_interfaces.begin_create_or_update(
|
|
||||||
resource_group, nic.name, nic
|
|
||||||
)
|
|
||||||
except (ResourceNotFoundError, CloudError) as err:
|
|
||||||
if is_concurrent_request_error(str(err)):
|
|
||||||
logging.debug(
|
|
||||||
"associate NSG with NIC had conflicts",
|
|
||||||
"with concurrent request, ignoring %s",
|
|
||||||
err,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.UNABLE_TO_UPDATE,
|
|
||||||
errors=[
|
|
||||||
"Unable to associate nsg %s with nic %s due to %s"
|
|
||||||
% (
|
|
||||||
name,
|
|
||||||
nic.name,
|
|
||||||
err,
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def dissociate_nic(name: str, nic: NetworkInterface) -> Union[None, Error]:
|
|
||||||
if nic.network_security_group is None:
|
|
||||||
return None
|
|
||||||
resource_group = get_base_resource_group()
|
|
||||||
nsg = get_nsg(name)
|
|
||||||
if not nsg:
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.UNABLE_TO_FIND,
|
|
||||||
errors=["cannot update nsg rules. nsg %s not found" % name],
|
|
||||||
)
|
|
||||||
if nsg.id != nic.network_security_group.id:
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.UNABLE_TO_UPDATE,
|
|
||||||
errors=[
|
|
||||||
"network interface is not associated with this nsg.",
|
|
||||||
"nsg %s, nic: %s, nic.nsg: %s"
|
|
||||||
% (
|
|
||||||
nsg.id,
|
|
||||||
nic.name,
|
|
||||||
nic.network_security_group.id,
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.info("dissociating nic %s with nsg: %s %s", nic.name, resource_group, name)
|
|
||||||
|
|
||||||
nic.network_security_group = None
|
|
||||||
network_client = get_network_client()
|
|
||||||
try:
|
|
||||||
network_client.network_interfaces.begin_create_or_update(
|
|
||||||
resource_group, nic.name, nic
|
|
||||||
)
|
|
||||||
except (ResourceNotFoundError, CloudError) as err:
|
|
||||||
if is_concurrent_request_error(str(err)):
|
|
||||||
logging.debug(
|
|
||||||
"dissociate nsg with nic had conflicts with ",
|
|
||||||
"concurrent request, ignoring %s",
|
|
||||||
err,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.UNABLE_TO_UPDATE,
|
|
||||||
errors=[
|
|
||||||
"Unable to dissociate nsg %s with nic %s due to %s"
|
|
||||||
% (
|
|
||||||
name,
|
|
||||||
nic.name,
|
|
||||||
err,
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def associate_subnet(
|
|
||||||
name: str, vnet: VirtualNetwork, subnet: Subnet
|
|
||||||
) -> Union[None, Error]:
|
|
||||||
|
|
||||||
resource_group = get_base_resource_group()
|
|
||||||
nsg = get_nsg(name)
|
|
||||||
if not nsg:
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.UNABLE_TO_FIND,
|
|
||||||
errors=["cannot associate subnet. nsg %s not found" % name],
|
|
||||||
)
|
|
||||||
|
|
||||||
if nsg.location != vnet.location:
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.UNABLE_TO_UPDATE,
|
|
||||||
errors=[
|
|
||||||
"subnet and nsg have to be in the same region.",
|
|
||||||
"nsg %s %s, subnet: %s %s"
|
|
||||||
% (nsg.name, nsg.location, subnet.name, subnet.location),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
if subnet.network_security_group and subnet.network_security_group.id == nsg.id:
|
|
||||||
logging.info(
|
|
||||||
"Subnet %s and NSG %s already associated, not updating", subnet.name, name
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
logging.info(
|
|
||||||
"associating subnet %s with nsg: %s %s", subnet.name, resource_group, name
|
|
||||||
)
|
|
||||||
|
|
||||||
subnet.network_security_group = nsg
|
|
||||||
network_client = get_network_client()
|
|
||||||
try:
|
|
||||||
network_client.subnets.begin_create_or_update(
|
|
||||||
resource_group, vnet.name, subnet.name, subnet
|
|
||||||
)
|
|
||||||
except (ResourceNotFoundError, CloudError) as err:
|
|
||||||
if is_concurrent_request_error(str(err)):
|
|
||||||
logging.debug(
|
|
||||||
"associate NSG with subnet had conflicts",
|
|
||||||
"with concurrent request, ignoring %s",
|
|
||||||
err,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.UNABLE_TO_UPDATE,
|
|
||||||
errors=[
|
|
||||||
"Unable to associate nsg %s with subnet %s due to %s"
|
|
||||||
% (
|
|
||||||
name,
|
|
||||||
subnet.name,
|
|
||||||
err,
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def dissociate_subnet(
|
|
||||||
name: str, vnet: VirtualNetwork, subnet: Subnet
|
|
||||||
) -> Union[None, Error]:
|
|
||||||
if subnet.network_security_group is None:
|
|
||||||
return None
|
|
||||||
resource_group = get_base_resource_group()
|
|
||||||
nsg = get_nsg(name)
|
|
||||||
if not nsg:
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.UNABLE_TO_FIND,
|
|
||||||
errors=["cannot update nsg rules. nsg %s not found" % name],
|
|
||||||
)
|
|
||||||
if nsg.id != subnet.network_security_group.id:
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.UNABLE_TO_UPDATE,
|
|
||||||
errors=[
|
|
||||||
"subnet is not associated with this nsg.",
|
|
||||||
"nsg %s, subnet: %s, subnet.nsg: %s"
|
|
||||||
% (
|
|
||||||
nsg.id,
|
|
||||||
subnet.name,
|
|
||||||
subnet.network_security_group.id,
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.info(
|
|
||||||
"dissociating subnet %s with nsg: %s %s", subnet.name, resource_group, name
|
|
||||||
)
|
|
||||||
|
|
||||||
subnet.network_security_group = None
|
|
||||||
network_client = get_network_client()
|
|
||||||
try:
|
|
||||||
network_client.subnets.begin_create_or_update(
|
|
||||||
resource_group, vnet.name, subnet.name, subnet
|
|
||||||
)
|
|
||||||
except (ResourceNotFoundError, CloudError) as err:
|
|
||||||
if is_concurrent_request_error(str(err)):
|
|
||||||
logging.debug(
|
|
||||||
"dissociate nsg with subnet had conflicts with ",
|
|
||||||
"concurrent request, ignoring %s",
|
|
||||||
err,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.UNABLE_TO_UPDATE,
|
|
||||||
errors=[
|
|
||||||
"Unable to dissociate nsg %s with subnet %s due to %s"
|
|
||||||
% (
|
|
||||||
name,
|
|
||||||
subnet.name,
|
|
||||||
err,
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class NSG(BaseModel):
|
|
||||||
name: str
|
|
||||||
region: Region
|
|
||||||
|
|
||||||
@validator("name", allow_reuse=True)
|
|
||||||
def check_name(cls, value: str) -> str:
|
|
||||||
# https://docs.microsoft.com/en-us/azure/azure-resource-manager/management/resource-name-rules
|
|
||||||
if len(value) > 80:
|
|
||||||
raise ValueError("NSG name too long")
|
|
||||||
return value
|
|
||||||
|
|
||||||
def create(self) -> Union[None, Error]:
|
|
||||||
# Optimization: if NSG exists - do not try
|
|
||||||
# to create it
|
|
||||||
if self.get() is not None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return create_nsg(self.name, self.region)
|
|
||||||
|
|
||||||
def start_delete(self) -> bool:
|
|
||||||
return start_delete_nsg(self.name)
|
|
||||||
|
|
||||||
def get(self) -> Optional[NetworkSecurityGroup]:
|
|
||||||
return get_nsg(self.name)
|
|
||||||
|
|
||||||
def set_allowed_sources(
|
|
||||||
self, sources: NetworkSecurityGroupConfig
|
|
||||||
) -> Union[None, Error]:
|
|
||||||
return set_allowed(self.name, sources)
|
|
||||||
|
|
||||||
def clear_all_rules(self) -> Union[None, Error]:
|
|
||||||
return clear_all_rules(self.name)
|
|
||||||
|
|
||||||
def get_all_rules(self) -> Union[Error, List[SecurityRule]]:
|
|
||||||
return get_all_rules(self.name)
|
|
||||||
|
|
||||||
def associate_nic(self, nic: NetworkInterface) -> Union[None, Error]:
|
|
||||||
return associate_nic(self.name, nic)
|
|
||||||
|
|
||||||
def dissociate_nic(self, nic: NetworkInterface) -> Union[None, Error]:
|
|
||||||
return dissociate_nic(self.name, nic)
|
|
||||||
|
|
||||||
def associate_subnet(
|
|
||||||
self, vnet: VirtualNetwork, subnet: Subnet
|
|
||||||
) -> Union[None, Error]:
|
|
||||||
return associate_subnet(self.name, vnet, subnet)
|
|
||||||
|
|
||||||
def dissociate_subnet(
|
|
||||||
self, vnet: VirtualNetwork, subnet: Subnet
|
|
||||||
) -> Union[None, Error]:
|
|
||||||
return dissociate_subnet(self.name, vnet, subnet)
|
|
@ -1,203 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import base64
|
|
||||||
import datetime
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from typing import List, Optional, Type, TypeVar, Union
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
|
|
||||||
from azure.storage.queue import (
|
|
||||||
QueueSasPermissions,
|
|
||||||
QueueServiceClient,
|
|
||||||
generate_queue_sas,
|
|
||||||
)
|
|
||||||
from memoization import cached
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from .storage import StorageType, get_primary_account, get_storage_account_name_key
|
|
||||||
|
|
||||||
QueueNameType = Union[str, UUID]
|
|
||||||
|
|
||||||
DEFAULT_TTL = -1
|
|
||||||
DEFAULT_DURATION = datetime.timedelta(days=30)
|
|
||||||
|
|
||||||
|
|
||||||
@cached(ttl=60)
|
|
||||||
def get_queue_client(storage_type: StorageType) -> QueueServiceClient:
|
|
||||||
account_id = get_primary_account(storage_type)
|
|
||||||
logging.debug("getting blob container (account_id: %s)", account_id)
|
|
||||||
name, key = get_storage_account_name_key(account_id)
|
|
||||||
account_url = "https://%s.queue.core.windows.net" % name
|
|
||||||
client = QueueServiceClient(
|
|
||||||
account_url=account_url,
|
|
||||||
credential={"account_name": name, "account_key": key},
|
|
||||||
)
|
|
||||||
return client
|
|
||||||
|
|
||||||
|
|
||||||
@cached(ttl=60)
|
|
||||||
def get_queue_sas(
|
|
||||||
queue: QueueNameType,
|
|
||||||
storage_type: StorageType,
|
|
||||||
*,
|
|
||||||
read: bool = False,
|
|
||||||
add: bool = False,
|
|
||||||
update: bool = False,
|
|
||||||
process: bool = False,
|
|
||||||
duration: Optional[datetime.timedelta] = None,
|
|
||||||
) -> str:
|
|
||||||
if duration is None:
|
|
||||||
duration = DEFAULT_DURATION
|
|
||||||
account_id = get_primary_account(storage_type)
|
|
||||||
logging.debug("getting queue sas %s (account_id: %s)", queue, account_id)
|
|
||||||
name, key = get_storage_account_name_key(account_id)
|
|
||||||
expiry = datetime.datetime.utcnow() + duration
|
|
||||||
|
|
||||||
token = generate_queue_sas(
|
|
||||||
name,
|
|
||||||
str(queue),
|
|
||||||
key,
|
|
||||||
permission=QueueSasPermissions(
|
|
||||||
read=read, add=add, update=update, process=process
|
|
||||||
),
|
|
||||||
expiry=expiry,
|
|
||||||
)
|
|
||||||
|
|
||||||
url = "https://%s.queue.core.windows.net/%s?%s" % (name, queue, token)
|
|
||||||
return url
|
|
||||||
|
|
||||||
|
|
||||||
@cached(ttl=60)
|
|
||||||
def create_queue(name: QueueNameType, storage_type: StorageType) -> None:
|
|
||||||
client = get_queue_client(storage_type)
|
|
||||||
try:
|
|
||||||
client.create_queue(str(name))
|
|
||||||
except ResourceExistsError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def delete_queue(name: QueueNameType, storage_type: StorageType) -> None:
|
|
||||||
client = get_queue_client(storage_type)
|
|
||||||
queues = client.list_queues()
|
|
||||||
if str(name) in [x["name"] for x in queues]:
|
|
||||||
client.delete_queue(name)
|
|
||||||
|
|
||||||
|
|
||||||
def get_queue(
|
|
||||||
name: QueueNameType, storage_type: StorageType
|
|
||||||
) -> Optional[QueueServiceClient]:
|
|
||||||
client = get_queue_client(storage_type)
|
|
||||||
try:
|
|
||||||
return client.get_queue_client(str(name))
|
|
||||||
except ResourceNotFoundError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def clear_queue(name: QueueNameType, storage_type: StorageType) -> None:
|
|
||||||
queue = get_queue(name, storage_type)
|
|
||||||
if queue:
|
|
||||||
try:
|
|
||||||
queue.clear_messages()
|
|
||||||
except ResourceNotFoundError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def send_message(
|
|
||||||
name: QueueNameType,
|
|
||||||
message: bytes,
|
|
||||||
storage_type: StorageType,
|
|
||||||
*,
|
|
||||||
visibility_timeout: Optional[int] = None,
|
|
||||||
time_to_live: int = DEFAULT_TTL,
|
|
||||||
) -> None:
|
|
||||||
queue = get_queue(name, storage_type)
|
|
||||||
if queue:
|
|
||||||
try:
|
|
||||||
queue.send_message(
|
|
||||||
base64.b64encode(message).decode(),
|
|
||||||
visibility_timeout=visibility_timeout,
|
|
||||||
time_to_live=time_to_live,
|
|
||||||
)
|
|
||||||
except ResourceNotFoundError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def remove_first_message(name: QueueNameType, storage_type: StorageType) -> bool:
|
|
||||||
queue = get_queue(name, storage_type)
|
|
||||||
if queue:
|
|
||||||
try:
|
|
||||||
for message in queue.receive_messages():
|
|
||||||
queue.delete_message(message)
|
|
||||||
return True
|
|
||||||
except ResourceNotFoundError:
|
|
||||||
return False
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
A = TypeVar("A", bound=BaseModel)
|
|
||||||
|
|
||||||
|
|
||||||
MIN_PEEK_SIZE = 1
|
|
||||||
MAX_PEEK_SIZE = 32
|
|
||||||
|
|
||||||
|
|
||||||
# Peek at a max of 32 messages
|
|
||||||
# https://docs.microsoft.com/en-us/python/api/azure-storage-queue/azure.storage.queue.queueclient
|
|
||||||
def peek_queue(
|
|
||||||
name: QueueNameType,
|
|
||||||
storage_type: StorageType,
|
|
||||||
*,
|
|
||||||
object_type: Type[A],
|
|
||||||
max_messages: int = MAX_PEEK_SIZE,
|
|
||||||
) -> List[A]:
|
|
||||||
result: List[A] = []
|
|
||||||
|
|
||||||
# message count
|
|
||||||
if max_messages < MIN_PEEK_SIZE or max_messages > MAX_PEEK_SIZE:
|
|
||||||
raise ValueError("invalid max messages: %s" % max_messages)
|
|
||||||
|
|
||||||
try:
|
|
||||||
queue = get_queue(name, storage_type)
|
|
||||||
if not queue:
|
|
||||||
return result
|
|
||||||
|
|
||||||
for message in queue.peek_messages(max_messages=max_messages):
|
|
||||||
decoded = base64.b64decode(message.content)
|
|
||||||
raw = json.loads(decoded)
|
|
||||||
result.append(object_type.parse_obj(raw))
|
|
||||||
except ResourceNotFoundError:
|
|
||||||
return result
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def queue_object(
|
|
||||||
name: QueueNameType,
|
|
||||||
message: BaseModel,
|
|
||||||
storage_type: StorageType,
|
|
||||||
*,
|
|
||||||
visibility_timeout: Optional[int] = None,
|
|
||||||
time_to_live: int = DEFAULT_TTL,
|
|
||||||
) -> bool:
|
|
||||||
queue = get_queue(name, storage_type)
|
|
||||||
if not queue:
|
|
||||||
raise Exception("unable to queue object, no such queue: %s" % queue)
|
|
||||||
|
|
||||||
encoded = base64.b64encode(message.json(exclude_none=True).encode()).decode()
|
|
||||||
try:
|
|
||||||
queue.send_message(
|
|
||||||
encoded, visibility_timeout=visibility_timeout, time_to_live=time_to_live
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
except ResourceNotFoundError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def get_resource_id(queue_name: QueueNameType, storage_type: StorageType) -> str:
|
|
||||||
account_id = get_primary_account(storage_type)
|
|
||||||
resource_uri = "%s/services/queue/queues/%s" % (account_id, queue_name)
|
|
||||||
return resource_uri
|
|
@ -1,120 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import random
|
|
||||||
from enum import Enum
|
|
||||||
from typing import List, Tuple, cast
|
|
||||||
|
|
||||||
from azure.mgmt.storage import StorageManagementClient
|
|
||||||
from memoization import cached
|
|
||||||
from msrestazure.tools import parse_resource_id
|
|
||||||
|
|
||||||
from .creds import get_base_resource_group, get_identity, get_subscription
|
|
||||||
|
|
||||||
|
|
||||||
class StorageType(Enum):
|
|
||||||
corpus = "corpus"
|
|
||||||
config = "config"
|
|
||||||
|
|
||||||
|
|
||||||
@cached
|
|
||||||
def get_mgmt_client() -> StorageManagementClient:
|
|
||||||
return StorageManagementClient(
|
|
||||||
credential=get_identity(), subscription_id=get_subscription()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@cached
|
|
||||||
def get_fuzz_storage() -> str:
|
|
||||||
return os.environ["ONEFUZZ_DATA_STORAGE"]
|
|
||||||
|
|
||||||
|
|
||||||
@cached
|
|
||||||
def get_func_storage() -> str:
|
|
||||||
return os.environ["ONEFUZZ_FUNC_STORAGE"]
|
|
||||||
|
|
||||||
|
|
||||||
@cached
|
|
||||||
def get_primary_account(storage_type: StorageType) -> str:
|
|
||||||
if storage_type == StorageType.corpus:
|
|
||||||
# see #322 for discussion about typing
|
|
||||||
return get_fuzz_storage()
|
|
||||||
elif storage_type == StorageType.config:
|
|
||||||
# see #322 for discussion about typing
|
|
||||||
return get_func_storage()
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
@cached
|
|
||||||
def get_accounts(storage_type: StorageType) -> List[str]:
|
|
||||||
if storage_type == StorageType.corpus:
|
|
||||||
return corpus_accounts()
|
|
||||||
elif storage_type == StorageType.config:
|
|
||||||
return [get_func_storage()]
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
@cached
|
|
||||||
def get_storage_account_name_key(account_id: str) -> Tuple[str, str]:
|
|
||||||
resource = parse_resource_id(account_id)
|
|
||||||
key = get_storage_account_name_key_by_name(resource["name"])
|
|
||||||
return resource["name"], key
|
|
||||||
|
|
||||||
|
|
||||||
@cached
|
|
||||||
def get_storage_account_name_key_by_name(account_name: str) -> str:
|
|
||||||
client = get_mgmt_client()
|
|
||||||
group = get_base_resource_group()
|
|
||||||
key = client.storage_accounts.list_keys(group, account_name).keys[0].value
|
|
||||||
return cast(str, key)
|
|
||||||
|
|
||||||
|
|
||||||
def choose_account(storage_type: StorageType) -> str:
|
|
||||||
accounts = get_accounts(storage_type)
|
|
||||||
if not accounts:
|
|
||||||
raise Exception(f"no storage accounts for {storage_type}")
|
|
||||||
|
|
||||||
if len(accounts) == 1:
|
|
||||||
return accounts[0]
|
|
||||||
|
|
||||||
# Use a random secondary storage account if any are available. This
|
|
||||||
# reduces IOP contention for the Storage Queues, which are only available
|
|
||||||
# on primary accounts
|
|
||||||
#
|
|
||||||
# security note: this is not used as a security feature
|
|
||||||
return random.choice(accounts[1:]) # nosec
|
|
||||||
|
|
||||||
|
|
||||||
@cached
|
|
||||||
def corpus_accounts() -> List[str]:
|
|
||||||
skip = get_func_storage()
|
|
||||||
results = [get_fuzz_storage()]
|
|
||||||
|
|
||||||
client = get_mgmt_client()
|
|
||||||
group = get_base_resource_group()
|
|
||||||
for account in client.storage_accounts.list_by_resource_group(group):
|
|
||||||
# protection from someone adding the corpus tag to the config account
|
|
||||||
if account.id == skip:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if account.id in results:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if account.primary_endpoints.blob is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if (
|
|
||||||
"storage_type" not in account.tags
|
|
||||||
or account.tags["storage_type"] != "corpus"
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
|
|
||||||
results.append(account.id)
|
|
||||||
|
|
||||||
logging.info("corpus accounts: %s", results)
|
|
||||||
return results
|
|
@ -1,97 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from typing import Any, Optional, Union, cast
|
|
||||||
|
|
||||||
from azure.core.exceptions import ResourceNotFoundError
|
|
||||||
from azure.mgmt.network.models import Subnet, VirtualNetwork
|
|
||||||
from msrestazure.azure_exceptions import CloudError
|
|
||||||
from onefuzztypes.enums import ErrorCode
|
|
||||||
from onefuzztypes.models import Error, NetworkConfig
|
|
||||||
from onefuzztypes.primitives import Region
|
|
||||||
|
|
||||||
from .network_mgmt_client import get_network_client
|
|
||||||
|
|
||||||
|
|
||||||
def get_vnet(resource_group: str, name: str) -> Optional[VirtualNetwork]:
|
|
||||||
network_client = get_network_client()
|
|
||||||
try:
|
|
||||||
vnet = network_client.virtual_networks.get(resource_group, name)
|
|
||||||
return cast(VirtualNetwork, vnet)
|
|
||||||
except (CloudError, ResourceNotFoundError):
|
|
||||||
logging.info(
|
|
||||||
"vnet missing: resource group:%s name:%s",
|
|
||||||
resource_group,
|
|
||||||
name,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def get_subnet(
|
|
||||||
resource_group: str, vnet_name: str, subnet_name: str
|
|
||||||
) -> Optional[Subnet]:
|
|
||||||
# Has to get using vnet. That way NSG field is properly set up in subnet
|
|
||||||
vnet = get_vnet(resource_group, vnet_name)
|
|
||||||
if vnet:
|
|
||||||
for subnet in vnet.subnets:
|
|
||||||
if subnet.name == subnet_name:
|
|
||||||
return subnet
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def get_subnet_id(resource_group: str, name: str, subnet_name: str) -> Optional[str]:
|
|
||||||
subnet = get_subnet(resource_group, name, subnet_name)
|
|
||||||
if subnet and isinstance(subnet.id, str):
|
|
||||||
return subnet.id
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def delete_subnet(resource_group: str, name: str) -> Union[None, CloudError, Any]:
|
|
||||||
network_client = get_network_client()
|
|
||||||
try:
|
|
||||||
return network_client.virtual_networks.begin_delete(resource_group, name)
|
|
||||||
except (CloudError, ResourceNotFoundError) as err:
|
|
||||||
if err.error and "InUseSubnetCannotBeDeleted" in str(err.error):
|
|
||||||
logging.error(
|
|
||||||
"subnet delete failed: %s %s : %s", resource_group, name, repr(err)
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
raise err
|
|
||||||
|
|
||||||
|
|
||||||
def create_virtual_network(
|
|
||||||
resource_group: str,
|
|
||||||
name: str,
|
|
||||||
region: Region,
|
|
||||||
network_config: NetworkConfig,
|
|
||||||
) -> Optional[Error]:
|
|
||||||
logging.info(
|
|
||||||
"creating subnet - resource group:%s name:%s region:%s",
|
|
||||||
resource_group,
|
|
||||||
name,
|
|
||||||
region,
|
|
||||||
)
|
|
||||||
|
|
||||||
network_client = get_network_client()
|
|
||||||
params = {
|
|
||||||
"location": region,
|
|
||||||
"address_space": {"address_prefixes": [network_config.address_space]},
|
|
||||||
"subnets": [{"name": name, "address_prefix": network_config.subnet}],
|
|
||||||
}
|
|
||||||
if "ONEFUZZ_OWNER" in os.environ:
|
|
||||||
params["tags"] = {"OWNER": os.environ["ONEFUZZ_OWNER"]}
|
|
||||||
|
|
||||||
try:
|
|
||||||
network_client.virtual_networks.begin_create_or_update(
|
|
||||||
resource_group, name, params
|
|
||||||
)
|
|
||||||
except (CloudError, ResourceNotFoundError) as err:
|
|
||||||
return Error(code=ErrorCode.UNABLE_TO_CREATE_NETWORK, errors=[str(err)])
|
|
||||||
|
|
||||||
return None
|
|
@ -1,30 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from azure.cosmosdb.table import TableService
|
|
||||||
from memoization import cached
|
|
||||||
|
|
||||||
from .storage import get_storage_account_name_key
|
|
||||||
|
|
||||||
|
|
||||||
@cached(ttl=60)
|
|
||||||
def get_client(
|
|
||||||
table: Optional[str] = None, account_id: Optional[str] = None
|
|
||||||
) -> TableService:
|
|
||||||
if account_id is None:
|
|
||||||
account_id = os.environ["ONEFUZZ_FUNC_STORAGE"]
|
|
||||||
|
|
||||||
logging.debug("getting table account: (account_id: %s)", account_id)
|
|
||||||
name, key = get_storage_account_name_key(account_id)
|
|
||||||
client = TableService(account_name=name, account_key=key)
|
|
||||||
|
|
||||||
if table and not client.exists(table):
|
|
||||||
logging.info("creating missing table %s", table)
|
|
||||||
client.create_table(table, fail_on_exist=False)
|
|
||||||
return client
|
|
@ -1,313 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from typing import Any, Dict, List, Optional, Union, cast
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from azure.core.exceptions import ResourceNotFoundError
|
|
||||||
from azure.mgmt.compute.models import VirtualMachine
|
|
||||||
from msrestazure.azure_exceptions import CloudError
|
|
||||||
from onefuzztypes.enums import OS, ErrorCode
|
|
||||||
from onefuzztypes.models import Authentication, Error
|
|
||||||
from onefuzztypes.primitives import Extension, Region
|
|
||||||
from pydantic import BaseModel, validator
|
|
||||||
|
|
||||||
from .compute import get_compute_client
|
|
||||||
from .creds import get_base_resource_group
|
|
||||||
from .disk import delete_disk, list_disks
|
|
||||||
from .image import get_os
|
|
||||||
from .ip import create_public_nic, delete_ip, delete_nic, get_ip, get_public_nic
|
|
||||||
from .nsg import NSG
|
|
||||||
|
|
||||||
|
|
||||||
def get_vm(name: str) -> Optional[VirtualMachine]:
|
|
||||||
resource_group = get_base_resource_group()
|
|
||||||
|
|
||||||
logging.debug("getting vm: %s", name)
|
|
||||||
compute_client = get_compute_client()
|
|
||||||
try:
|
|
||||||
return cast(
|
|
||||||
VirtualMachine,
|
|
||||||
compute_client.virtual_machines.get(
|
|
||||||
resource_group, name, expand="instanceView"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
except (ResourceNotFoundError, CloudError) as err:
|
|
||||||
logging.debug("vm does not exist %s", err)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def create_vm(
|
|
||||||
name: str,
|
|
||||||
location: Region,
|
|
||||||
vm_sku: str,
|
|
||||||
image: str,
|
|
||||||
password: str,
|
|
||||||
ssh_public_key: str,
|
|
||||||
nsg: Optional[NSG],
|
|
||||||
tags: Optional[Dict[str, str]],
|
|
||||||
) -> Union[None, Error]:
|
|
||||||
|
|
||||||
resource_group = get_base_resource_group()
|
|
||||||
logging.info("creating vm %s:%s:%s", resource_group, location, name)
|
|
||||||
|
|
||||||
compute_client = get_compute_client()
|
|
||||||
|
|
||||||
nic = get_public_nic(resource_group, name)
|
|
||||||
if nic is None:
|
|
||||||
result = create_public_nic(resource_group, name, location, nsg)
|
|
||||||
if isinstance(result, Error):
|
|
||||||
return result
|
|
||||||
logging.info("waiting on nic creation")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# when public nic is created, VNET must exist at that point
|
|
||||||
# this is logic of get_public_nic function
|
|
||||||
|
|
||||||
if nsg:
|
|
||||||
result = nsg.associate_nic(nic)
|
|
||||||
if isinstance(result, Error):
|
|
||||||
return result
|
|
||||||
|
|
||||||
if image.startswith("/"):
|
|
||||||
image_ref = {"id": image}
|
|
||||||
else:
|
|
||||||
image_val = image.split(":", 4)
|
|
||||||
image_ref = {
|
|
||||||
"publisher": image_val[0],
|
|
||||||
"offer": image_val[1],
|
|
||||||
"sku": image_val[2],
|
|
||||||
"version": image_val[3],
|
|
||||||
}
|
|
||||||
|
|
||||||
params: Dict = {
|
|
||||||
"location": location,
|
|
||||||
"os_profile": {
|
|
||||||
"computer_name": "node",
|
|
||||||
"admin_username": "onefuzz",
|
|
||||||
},
|
|
||||||
"hardware_profile": {"vm_size": vm_sku},
|
|
||||||
"storage_profile": {"image_reference": image_ref},
|
|
||||||
"network_profile": {"network_interfaces": [{"id": nic.id}]},
|
|
||||||
}
|
|
||||||
|
|
||||||
image_os = get_os(location, image)
|
|
||||||
if isinstance(image_os, Error):
|
|
||||||
return image_os
|
|
||||||
|
|
||||||
if image_os == OS.windows:
|
|
||||||
params["os_profile"]["admin_password"] = password
|
|
||||||
|
|
||||||
if image_os == OS.linux:
|
|
||||||
|
|
||||||
params["os_profile"]["linux_configuration"] = {
|
|
||||||
"disable_password_authentication": True,
|
|
||||||
"ssh": {
|
|
||||||
"public_keys": [
|
|
||||||
{
|
|
||||||
"path": "/home/onefuzz/.ssh/authorized_keys",
|
|
||||||
"key_data": ssh_public_key,
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if "ONEFUZZ_OWNER" in os.environ:
|
|
||||||
params["tags"] = {"OWNER": os.environ["ONEFUZZ_OWNER"]}
|
|
||||||
|
|
||||||
if tags:
|
|
||||||
params["tags"].update(tags.copy())
|
|
||||||
|
|
||||||
try:
|
|
||||||
compute_client.virtual_machines.begin_create_or_update(
|
|
||||||
resource_group, name, params
|
|
||||||
)
|
|
||||||
except (ResourceNotFoundError, CloudError) as err:
|
|
||||||
if "The request failed due to conflict with a concurrent request" in str(err):
|
|
||||||
logging.debug(
|
|
||||||
"create VM had conflicts with concurrent request, ignoring %s", err
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
return Error(code=ErrorCode.VM_CREATE_FAILED, errors=[str(err)])
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def get_extension(vm_name: str, extension_name: str) -> Optional[Any]:
|
|
||||||
resource_group = get_base_resource_group()
|
|
||||||
|
|
||||||
logging.debug(
|
|
||||||
"getting extension: %s:%s:%s",
|
|
||||||
resource_group,
|
|
||||||
vm_name,
|
|
||||||
extension_name,
|
|
||||||
)
|
|
||||||
compute_client = get_compute_client()
|
|
||||||
try:
|
|
||||||
return compute_client.virtual_machine_extensions.get(
|
|
||||||
resource_group, vm_name, extension_name
|
|
||||||
)
|
|
||||||
except (ResourceNotFoundError, CloudError) as err:
|
|
||||||
logging.info("extension does not exist %s", err)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def create_extension(vm_name: str, extension: Dict) -> Any:
|
|
||||||
resource_group = get_base_resource_group()
|
|
||||||
|
|
||||||
logging.info(
|
|
||||||
"creating extension: %s:%s:%s", resource_group, vm_name, extension["name"]
|
|
||||||
)
|
|
||||||
compute_client = get_compute_client()
|
|
||||||
return compute_client.virtual_machine_extensions.begin_create_or_update(
|
|
||||||
resource_group, vm_name, extension["name"], extension
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def delete_vm(name: str) -> Any:
|
|
||||||
resource_group = get_base_resource_group()
|
|
||||||
|
|
||||||
logging.info("deleting vm: %s %s", resource_group, name)
|
|
||||||
compute_client = get_compute_client()
|
|
||||||
return compute_client.virtual_machines.begin_delete(resource_group, name)
|
|
||||||
|
|
||||||
|
|
||||||
def has_components(name: str) -> bool:
|
|
||||||
# check if any of the components associated with a VM still exist.
|
|
||||||
#
|
|
||||||
# Azure VM Deletion requires we first delete the VM, then delete all of it's
|
|
||||||
# resources. This is required to ensure we've cleaned it all up before
|
|
||||||
# marking it "done"
|
|
||||||
resource_group = get_base_resource_group()
|
|
||||||
if get_vm(name):
|
|
||||||
return True
|
|
||||||
if get_public_nic(resource_group, name):
|
|
||||||
return True
|
|
||||||
if get_ip(resource_group, name):
|
|
||||||
return True
|
|
||||||
|
|
||||||
disks = [x.name for x in list_disks(resource_group) if x.name.startswith(name)]
|
|
||||||
if disks:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def delete_vm_components(name: str, nsg: Optional[NSG]) -> bool:
|
|
||||||
resource_group = get_base_resource_group()
|
|
||||||
logging.info("deleting vm components %s:%s", resource_group, name)
|
|
||||||
if get_vm(name):
|
|
||||||
logging.info("deleting vm %s:%s", resource_group, name)
|
|
||||||
delete_vm(name)
|
|
||||||
return False
|
|
||||||
|
|
||||||
nic = get_public_nic(resource_group, name)
|
|
||||||
if nic:
|
|
||||||
logging.info("deleting nic %s:%s", resource_group, name)
|
|
||||||
if nic.network_security_group and nsg:
|
|
||||||
nsg.dissociate_nic(nic)
|
|
||||||
return False
|
|
||||||
delete_nic(resource_group, name)
|
|
||||||
return False
|
|
||||||
|
|
||||||
if get_ip(resource_group, name):
|
|
||||||
logging.info("deleting ip %s:%s", resource_group, name)
|
|
||||||
delete_ip(resource_group, name)
|
|
||||||
return False
|
|
||||||
|
|
||||||
disks = [x.name for x in list_disks(resource_group) if x.name.startswith(name)]
|
|
||||||
if disks:
|
|
||||||
for disk in disks:
|
|
||||||
logging.info("deleting disk %s:%s", resource_group, disk)
|
|
||||||
delete_disk(resource_group, disk)
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class VM(BaseModel):
|
|
||||||
name: Union[UUID, str]
|
|
||||||
region: Region
|
|
||||||
sku: str
|
|
||||||
image: str
|
|
||||||
auth: Authentication
|
|
||||||
nsg: Optional[NSG]
|
|
||||||
tags: Optional[Dict[str, str]]
|
|
||||||
|
|
||||||
@validator("name", allow_reuse=True)
|
|
||||||
def check_name(cls, value: Union[UUID, str]) -> Union[UUID, str]:
|
|
||||||
if isinstance(value, str):
|
|
||||||
if len(value) > 40:
|
|
||||||
# Azure truncates resources if the names are longer than 40
|
|
||||||
# bytes
|
|
||||||
raise ValueError("VM name too long")
|
|
||||||
return value
|
|
||||||
|
|
||||||
def is_deleted(self) -> bool:
|
|
||||||
# A VM is considered deleted once all of it's resources including disks,
|
|
||||||
# NICs, IPs, as well as the VM are deleted
|
|
||||||
return not has_components(str(self.name))
|
|
||||||
|
|
||||||
def exists(self) -> bool:
|
|
||||||
return self.get() is not None
|
|
||||||
|
|
||||||
def get(self) -> Optional[VirtualMachine]:
|
|
||||||
return get_vm(str(self.name))
|
|
||||||
|
|
||||||
def create(self) -> Union[None, Error]:
|
|
||||||
if self.get() is not None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
logging.info("vm creating: %s", self.name)
|
|
||||||
return create_vm(
|
|
||||||
str(self.name),
|
|
||||||
self.region,
|
|
||||||
self.sku,
|
|
||||||
self.image,
|
|
||||||
self.auth.password,
|
|
||||||
self.auth.public_key,
|
|
||||||
self.nsg,
|
|
||||||
self.tags,
|
|
||||||
)
|
|
||||||
|
|
||||||
def delete(self) -> bool:
|
|
||||||
return delete_vm_components(str(self.name), self.nsg)
|
|
||||||
|
|
||||||
def add_extensions(self, extensions: List[Extension]) -> Union[bool, Error]:
|
|
||||||
status = []
|
|
||||||
to_create = []
|
|
||||||
for config in extensions:
|
|
||||||
if not isinstance(config["name"], str):
|
|
||||||
logging.error("vm agent - incompatable name: %s", repr(config))
|
|
||||||
continue
|
|
||||||
extension = get_extension(str(self.name), config["name"])
|
|
||||||
|
|
||||||
if extension:
|
|
||||||
logging.info(
|
|
||||||
"vm extension state: %s - %s - %s",
|
|
||||||
self.name,
|
|
||||||
config["name"],
|
|
||||||
extension.provisioning_state,
|
|
||||||
)
|
|
||||||
status.append(extension.provisioning_state)
|
|
||||||
else:
|
|
||||||
to_create.append(config)
|
|
||||||
|
|
||||||
if to_create:
|
|
||||||
for config in to_create:
|
|
||||||
create_extension(str(self.name), config)
|
|
||||||
else:
|
|
||||||
if all([x == "Succeeded" for x in status]):
|
|
||||||
return True
|
|
||||||
elif "Failed" in status:
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.VM_CREATE_FAILED,
|
|
||||||
errors=["failed to launch extension"],
|
|
||||||
)
|
|
||||||
elif not ("Creating" in status or "Updating" in status):
|
|
||||||
logging.error("vm agent - unknown state %s: %s", self.name, status)
|
|
||||||
|
|
||||||
return False
|
|
@ -1,541 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from typing import Any, Callable, Dict, List, Optional, Set, Union, cast
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from azure.core.exceptions import (
|
|
||||||
HttpResponseError,
|
|
||||||
ResourceExistsError,
|
|
||||||
ResourceNotFoundError,
|
|
||||||
)
|
|
||||||
from azure.mgmt.compute.models import (
|
|
||||||
ResourceSku,
|
|
||||||
ResourceSkuRestrictionsType,
|
|
||||||
VirtualMachineScaleSetVMInstanceIDs,
|
|
||||||
VirtualMachineScaleSetVMInstanceRequiredIDs,
|
|
||||||
VirtualMachineScaleSetVMListResult,
|
|
||||||
VirtualMachineScaleSetVMProtectionPolicy,
|
|
||||||
)
|
|
||||||
from memoization import cached
|
|
||||||
from msrestazure.azure_exceptions import CloudError
|
|
||||||
from onefuzztypes.enums import OS, ErrorCode
|
|
||||||
from onefuzztypes.models import Error
|
|
||||||
from onefuzztypes.primitives import Region
|
|
||||||
|
|
||||||
from .compute import get_compute_client
|
|
||||||
from .creds import (
|
|
||||||
get_base_resource_group,
|
|
||||||
get_scaleset_identity_resource_path,
|
|
||||||
retry_on_auth_failure,
|
|
||||||
)
|
|
||||||
from .image import get_os
|
|
||||||
|
|
||||||
|
|
||||||
@retry_on_auth_failure()
|
|
||||||
def list_vmss(
|
|
||||||
name: UUID,
|
|
||||||
vm_filter: Optional[Callable[[VirtualMachineScaleSetVMListResult], bool]] = None,
|
|
||||||
) -> Optional[List[str]]:
|
|
||||||
resource_group = get_base_resource_group()
|
|
||||||
client = get_compute_client()
|
|
||||||
try:
|
|
||||||
instances = [
|
|
||||||
x.instance_id
|
|
||||||
for x in client.virtual_machine_scale_set_vms.list(
|
|
||||||
resource_group, str(name)
|
|
||||||
)
|
|
||||||
if vm_filter is None or vm_filter(x)
|
|
||||||
]
|
|
||||||
return instances
|
|
||||||
except (ResourceNotFoundError, CloudError) as err:
|
|
||||||
logging.error("cloud error listing vmss: %s (%s)", name, err)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
@retry_on_auth_failure()
|
|
||||||
def delete_vmss(name: UUID) -> bool:
|
|
||||||
resource_group = get_base_resource_group()
|
|
||||||
compute_client = get_compute_client()
|
|
||||||
response = compute_client.virtual_machine_scale_sets.begin_delete(
|
|
||||||
resource_group, str(name)
|
|
||||||
)
|
|
||||||
|
|
||||||
# https://docs.microsoft.com/en-us/python/api/azure-core/
|
|
||||||
# azure.core.polling.lropoller?view=azure-python#status--
|
|
||||||
#
|
|
||||||
# status returns a str, however mypy thinks this is an Any.
|
|
||||||
#
|
|
||||||
# Checked by hand that the result is Succeeded in practice
|
|
||||||
return bool(response.status() == "Succeeded")
|
|
||||||
|
|
||||||
|
|
||||||
@retry_on_auth_failure()
|
|
||||||
def get_vmss(name: UUID) -> Optional[Any]:
|
|
||||||
resource_group = get_base_resource_group()
|
|
||||||
logging.debug("getting vm: %s", name)
|
|
||||||
compute_client = get_compute_client()
|
|
||||||
try:
|
|
||||||
return compute_client.virtual_machine_scale_sets.get(resource_group, str(name))
|
|
||||||
except ResourceNotFoundError as err:
|
|
||||||
logging.debug("vm does not exist %s", err)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
@retry_on_auth_failure()
|
|
||||||
def resize_vmss(name: UUID, capacity: int) -> None:
|
|
||||||
check_can_update(name)
|
|
||||||
|
|
||||||
resource_group = get_base_resource_group()
|
|
||||||
logging.info("updating VM count - name: %s vm_count: %d", name, capacity)
|
|
||||||
compute_client = get_compute_client()
|
|
||||||
try:
|
|
||||||
compute_client.virtual_machine_scale_sets.begin_update(
|
|
||||||
resource_group, str(name), {"sku": {"capacity": capacity}}
|
|
||||||
)
|
|
||||||
except ResourceExistsError as err:
|
|
||||||
logging.error(
|
|
||||||
"unable to resize scaleset. name:%s vm_count:%d - err:%s",
|
|
||||||
name,
|
|
||||||
capacity,
|
|
||||||
err,
|
|
||||||
)
|
|
||||||
except HttpResponseError as err:
|
|
||||||
if (
|
|
||||||
"models that may be referenced by one or more"
|
|
||||||
+ " VMs belonging to the Virtual Machine Scale Set"
|
|
||||||
in str(err)
|
|
||||||
):
|
|
||||||
logging.error(
|
|
||||||
"unable to resize scaleset due to model error. name: %s - err:%s",
|
|
||||||
name,
|
|
||||||
err,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@retry_on_auth_failure()
|
|
||||||
def get_vmss_size(name: UUID) -> Optional[int]:
|
|
||||||
vmss = get_vmss(name)
|
|
||||||
if vmss is None:
|
|
||||||
return None
|
|
||||||
return cast(int, vmss.sku.capacity)
|
|
||||||
|
|
||||||
|
|
||||||
@retry_on_auth_failure()
|
|
||||||
def list_instance_ids(name: UUID) -> Dict[UUID, str]:
|
|
||||||
logging.debug("get instance IDs for scaleset: %s", name)
|
|
||||||
resource_group = get_base_resource_group()
|
|
||||||
compute_client = get_compute_client()
|
|
||||||
|
|
||||||
results = {}
|
|
||||||
try:
|
|
||||||
for instance in compute_client.virtual_machine_scale_set_vms.list(
|
|
||||||
resource_group, str(name)
|
|
||||||
):
|
|
||||||
results[UUID(instance.vm_id)] = cast(str, instance.instance_id)
|
|
||||||
except (ResourceNotFoundError, CloudError):
|
|
||||||
logging.debug("vm does not exist %s", name)
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
@cached(ttl=60)
|
|
||||||
@retry_on_auth_failure()
|
|
||||||
def get_instance_id(name: UUID, vm_id: UUID) -> Union[str, Error]:
|
|
||||||
resource_group = get_base_resource_group()
|
|
||||||
logging.info("get instance ID for scaleset node: %s:%s", name, vm_id)
|
|
||||||
compute_client = get_compute_client()
|
|
||||||
|
|
||||||
vm_id_str = str(vm_id)
|
|
||||||
for instance in compute_client.virtual_machine_scale_set_vms.list(
|
|
||||||
resource_group, str(name)
|
|
||||||
):
|
|
||||||
if instance.vm_id == vm_id_str:
|
|
||||||
return cast(str, instance.instance_id)
|
|
||||||
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.UNABLE_TO_FIND,
|
|
||||||
errors=["unable to find scaleset machine: %s:%s" % (name, vm_id)],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@retry_on_auth_failure()
|
|
||||||
def update_scale_in_protection(
|
|
||||||
name: UUID, vm_id: UUID, protect_from_scale_in: bool
|
|
||||||
) -> Optional[Error]:
|
|
||||||
instance_id = get_instance_id(name, vm_id)
|
|
||||||
|
|
||||||
if isinstance(instance_id, Error):
|
|
||||||
return instance_id
|
|
||||||
|
|
||||||
compute_client = get_compute_client()
|
|
||||||
resource_group = get_base_resource_group()
|
|
||||||
|
|
||||||
try:
|
|
||||||
instance_vm = compute_client.virtual_machine_scale_set_vms.get(
|
|
||||||
resource_group, name, instance_id
|
|
||||||
)
|
|
||||||
except (ResourceNotFoundError, CloudError):
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.UNABLE_TO_FIND,
|
|
||||||
errors=["unable to find vm instance: %s:%s" % (name, instance_id)],
|
|
||||||
)
|
|
||||||
|
|
||||||
new_protection_policy = VirtualMachineScaleSetVMProtectionPolicy(
|
|
||||||
protect_from_scale_in=protect_from_scale_in
|
|
||||||
)
|
|
||||||
if instance_vm.protection_policy is not None:
|
|
||||||
new_protection_policy = instance_vm.protection_policy
|
|
||||||
new_protection_policy.protect_from_scale_in = protect_from_scale_in
|
|
||||||
|
|
||||||
instance_vm.protection_policy = new_protection_policy
|
|
||||||
|
|
||||||
try:
|
|
||||||
compute_client.virtual_machine_scale_set_vms.begin_update(
|
|
||||||
resource_group, name, instance_id, instance_vm
|
|
||||||
)
|
|
||||||
except (ResourceNotFoundError, CloudError, HttpResponseError) as err:
|
|
||||||
if isinstance(err, HttpResponseError):
|
|
||||||
err_str = str(err)
|
|
||||||
instance_not_found = (
|
|
||||||
" is not an active Virtual Machine Scale Set VM instanceId."
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
instance_not_found in err_str
|
|
||||||
and instance_vm.protection_policy.protect_from_scale_in is False
|
|
||||||
and protect_from_scale_in
|
|
||||||
== instance_vm.protection_policy.protect_from_scale_in
|
|
||||||
):
|
|
||||||
logging.info(
|
|
||||||
"Tried to remove scale in protection on node %s but the instance no longer exists" # noqa: E501
|
|
||||||
% instance_id
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
if (
|
|
||||||
"models that may be referenced by one or more"
|
|
||||||
+ " VMs belonging to the Virtual Machine Scale Set"
|
|
||||||
in str(err)
|
|
||||||
):
|
|
||||||
logging.error(
|
|
||||||
"unable to resize scaleset due to model error. name: %s - err:%s",
|
|
||||||
name,
|
|
||||||
err,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.UNABLE_TO_UPDATE,
|
|
||||||
errors=["unable to set protection policy on: %s:%s" % (vm_id, instance_id)],
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.info(
|
|
||||||
"Successfully set scale in protection on node %s to %s"
|
|
||||||
% (vm_id, protect_from_scale_in)
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class UnableToUpdate(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def check_can_update(name: UUID) -> Any:
|
|
||||||
vmss = get_vmss(name)
|
|
||||||
if vmss is None:
|
|
||||||
raise UnableToUpdate
|
|
||||||
|
|
||||||
if vmss.provisioning_state == "Updating":
|
|
||||||
raise UnableToUpdate
|
|
||||||
|
|
||||||
return vmss
|
|
||||||
|
|
||||||
|
|
||||||
@retry_on_auth_failure()
|
|
||||||
def reimage_vmss_nodes(name: UUID, vm_ids: Set[UUID]) -> Optional[Error]:
|
|
||||||
check_can_update(name)
|
|
||||||
|
|
||||||
resource_group = get_base_resource_group()
|
|
||||||
compute_client = get_compute_client()
|
|
||||||
|
|
||||||
instance_ids = set()
|
|
||||||
machine_to_id = list_instance_ids(name)
|
|
||||||
for vm_id in vm_ids:
|
|
||||||
if vm_id in machine_to_id:
|
|
||||||
instance_ids.add(machine_to_id[vm_id])
|
|
||||||
else:
|
|
||||||
logging.info("unable to find vm_id for %s:%s", name, vm_id)
|
|
||||||
|
|
||||||
# Nodes that must be are 'upgraded' before the reimage. This call makes sure
|
|
||||||
# the instance is up-to-date with the VMSS model.
|
|
||||||
# The expectation is that these requests are queued and handled subsequently.
|
|
||||||
# The VMSS Team confirmed this expectation and testing supports it, as well.
|
|
||||||
if instance_ids:
|
|
||||||
logging.info("upgrading VMSS nodes - name: %s vm_ids: %s", name, vm_id)
|
|
||||||
compute_client.virtual_machine_scale_sets.begin_update_instances(
|
|
||||||
resource_group,
|
|
||||||
str(name),
|
|
||||||
VirtualMachineScaleSetVMInstanceIDs(instance_ids=list(instance_ids)),
|
|
||||||
)
|
|
||||||
logging.info("reimaging VMSS nodes - name: %s vm_ids: %s", name, vm_id)
|
|
||||||
compute_client.virtual_machine_scale_sets.begin_reimage_all(
|
|
||||||
resource_group,
|
|
||||||
str(name),
|
|
||||||
VirtualMachineScaleSetVMInstanceIDs(instance_ids=list(instance_ids)),
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
@retry_on_auth_failure()
|
|
||||||
def delete_vmss_nodes(name: UUID, vm_ids: Set[UUID]) -> Optional[Error]:
|
|
||||||
check_can_update(name)
|
|
||||||
|
|
||||||
resource_group = get_base_resource_group()
|
|
||||||
logging.info("deleting scaleset VM - name: %s vm_ids:%s", name, vm_ids)
|
|
||||||
compute_client = get_compute_client()
|
|
||||||
|
|
||||||
instance_ids = set()
|
|
||||||
machine_to_id = list_instance_ids(name)
|
|
||||||
for vm_id in vm_ids:
|
|
||||||
if vm_id in machine_to_id:
|
|
||||||
instance_ids.add(machine_to_id[vm_id])
|
|
||||||
else:
|
|
||||||
logging.info("unable to find vm_id for %s:%s", name, vm_id)
|
|
||||||
|
|
||||||
if instance_ids:
|
|
||||||
compute_client.virtual_machine_scale_sets.begin_delete_instances(
|
|
||||||
resource_group,
|
|
||||||
str(name),
|
|
||||||
VirtualMachineScaleSetVMInstanceRequiredIDs(
|
|
||||||
instance_ids=list(instance_ids)
|
|
||||||
),
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
@retry_on_auth_failure()
|
|
||||||
def update_extensions(name: UUID, extensions: List[Any]) -> None:
|
|
||||||
check_can_update(name)
|
|
||||||
|
|
||||||
resource_group = get_base_resource_group()
|
|
||||||
logging.info("updating VM extensions: %s", name)
|
|
||||||
compute_client = get_compute_client()
|
|
||||||
try:
|
|
||||||
compute_client.virtual_machine_scale_sets.begin_update(
|
|
||||||
resource_group,
|
|
||||||
str(name),
|
|
||||||
{
|
|
||||||
"virtual_machine_profile": {
|
|
||||||
"extension_profile": {"extensions": extensions}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
logging.info("VM extensions updated: %s", name)
|
|
||||||
except HttpResponseError as err:
|
|
||||||
if (
|
|
||||||
"models that may be referenced by one or more"
|
|
||||||
+ " VMs belonging to the Virtual Machine Scale Set"
|
|
||||||
in str(err)
|
|
||||||
):
|
|
||||||
logging.error(
|
|
||||||
"unable to resize scaleset due to model error. name: %s - err:%s",
|
|
||||||
name,
|
|
||||||
err,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@retry_on_auth_failure()
|
|
||||||
def create_vmss(
|
|
||||||
location: Region,
|
|
||||||
name: UUID,
|
|
||||||
vm_sku: str,
|
|
||||||
vm_count: int,
|
|
||||||
image: str,
|
|
||||||
network_id: str,
|
|
||||||
spot_instances: bool,
|
|
||||||
ephemeral_os_disks: bool,
|
|
||||||
extensions: List[Any],
|
|
||||||
password: str,
|
|
||||||
ssh_public_key: str,
|
|
||||||
tags: Dict[str, str],
|
|
||||||
) -> Optional[Error]:
|
|
||||||
|
|
||||||
vmss = get_vmss(name)
|
|
||||||
if vmss is not None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
logging.info(
|
|
||||||
"creating VM "
|
|
||||||
"name: %s vm_sku: %s vm_count: %d "
|
|
||||||
"image: %s subnet: %s spot_instances: %s",
|
|
||||||
name,
|
|
||||||
vm_sku,
|
|
||||||
vm_count,
|
|
||||||
image,
|
|
||||||
network_id,
|
|
||||||
spot_instances,
|
|
||||||
)
|
|
||||||
|
|
||||||
resource_group = get_base_resource_group()
|
|
||||||
|
|
||||||
compute_client = get_compute_client()
|
|
||||||
|
|
||||||
if image.startswith("/"):
|
|
||||||
image_ref = {"id": image}
|
|
||||||
else:
|
|
||||||
image_val = image.split(":", 4)
|
|
||||||
image_ref = {
|
|
||||||
"publisher": image_val[0],
|
|
||||||
"offer": image_val[1],
|
|
||||||
"sku": image_val[2],
|
|
||||||
"version": image_val[3],
|
|
||||||
}
|
|
||||||
|
|
||||||
sku = {"name": vm_sku, "tier": "Standard", "capacity": vm_count}
|
|
||||||
|
|
||||||
params: Dict[str, Any] = {
|
|
||||||
"location": location,
|
|
||||||
"do_not_run_extensions_on_overprovisioned_vms": True,
|
|
||||||
"upgrade_policy": {"mode": "Manual"},
|
|
||||||
"sku": sku,
|
|
||||||
"overprovision": False,
|
|
||||||
"identity": {
|
|
||||||
"type": "userAssigned",
|
|
||||||
"userAssignedIdentities": {get_scaleset_identity_resource_path(): {}},
|
|
||||||
},
|
|
||||||
"virtual_machine_profile": {
|
|
||||||
"priority": "Regular",
|
|
||||||
"storage_profile": {
|
|
||||||
"image_reference": image_ref,
|
|
||||||
},
|
|
||||||
"os_profile": {
|
|
||||||
"computer_name_prefix": "node",
|
|
||||||
"admin_username": "onefuzz",
|
|
||||||
},
|
|
||||||
"network_profile": {
|
|
||||||
"network_interface_configurations": [
|
|
||||||
{
|
|
||||||
"name": "onefuzz-nic",
|
|
||||||
"primary": True,
|
|
||||||
"ip_configurations": [
|
|
||||||
{"name": "onefuzz-ip-config", "subnet": {"id": network_id}}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"extension_profile": {"extensions": extensions},
|
|
||||||
},
|
|
||||||
"single_placement_group": False,
|
|
||||||
}
|
|
||||||
|
|
||||||
image_os = get_os(location, image)
|
|
||||||
if isinstance(image_os, Error):
|
|
||||||
return image_os
|
|
||||||
|
|
||||||
if image_os == OS.windows:
|
|
||||||
params["virtual_machine_profile"]["os_profile"]["admin_password"] = password
|
|
||||||
|
|
||||||
if image_os == OS.linux:
|
|
||||||
params["virtual_machine_profile"]["os_profile"]["linux_configuration"] = {
|
|
||||||
"disable_password_authentication": True,
|
|
||||||
"ssh": {
|
|
||||||
"public_keys": [
|
|
||||||
{
|
|
||||||
"path": "/home/onefuzz/.ssh/authorized_keys",
|
|
||||||
"key_data": ssh_public_key,
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if ephemeral_os_disks:
|
|
||||||
params["virtual_machine_profile"]["storage_profile"]["os_disk"] = {
|
|
||||||
"diffDiskSettings": {"option": "Local"},
|
|
||||||
"caching": "ReadOnly",
|
|
||||||
"createOption": "FromImage",
|
|
||||||
}
|
|
||||||
|
|
||||||
if spot_instances:
|
|
||||||
# Setting max price to -1 means it won't be evicted because of
|
|
||||||
# price.
|
|
||||||
#
|
|
||||||
# https://docs.microsoft.com/en-us/azure/
|
|
||||||
# virtual-machine-scale-sets/use-spot#resource-manager-templates
|
|
||||||
params["virtual_machine_profile"].update(
|
|
||||||
{
|
|
||||||
"eviction_policy": "Delete",
|
|
||||||
"priority": "Spot",
|
|
||||||
"billing_profile": {"max_price": -1},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
params["tags"] = tags.copy()
|
|
||||||
|
|
||||||
owner = os.environ.get("ONEFUZZ_OWNER")
|
|
||||||
if owner:
|
|
||||||
params["tags"]["OWNER"] = owner
|
|
||||||
|
|
||||||
try:
|
|
||||||
compute_client.virtual_machine_scale_sets.begin_create_or_update(
|
|
||||||
resource_group, name, params
|
|
||||||
)
|
|
||||||
except ResourceExistsError as err:
|
|
||||||
err_str = str(err)
|
|
||||||
if "SkuNotAvailable" in err_str or "OperationNotAllowed" in err_str:
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.VM_CREATE_FAILED, errors=[f"creating vmss: {err_str}"]
|
|
||||||
)
|
|
||||||
raise err
|
|
||||||
except (ResourceNotFoundError, CloudError) as err:
|
|
||||||
if "The request failed due to conflict with a concurrent request" in repr(err):
|
|
||||||
logging.debug(
|
|
||||||
"create VM had conflicts with concurrent request, ignoring %s", err
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.VM_CREATE_FAILED,
|
|
||||||
errors=["creating vmss: %s" % err],
|
|
||||||
)
|
|
||||||
except HttpResponseError as err:
|
|
||||||
err_str = str(err)
|
|
||||||
# Catch Gen2 Hypervisor / Image mismatch errors
|
|
||||||
# See https://github.com/microsoft/lisa/pull/716 for an example
|
|
||||||
if (
|
|
||||||
"check that the Hypervisor Generation of the Image matches the "
|
|
||||||
"Hypervisor Generation of the selected VM Size"
|
|
||||||
) in err_str:
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.VM_CREATE_FAILED, errors=[f"creating vmss: {err_str}"]
|
|
||||||
)
|
|
||||||
raise err
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
@cached(ttl=60)
|
|
||||||
@retry_on_auth_failure()
|
|
||||||
def list_available_skus(region: Region) -> List[str]:
|
|
||||||
compute_client = get_compute_client()
|
|
||||||
|
|
||||||
skus: List[ResourceSku] = list(
|
|
||||||
compute_client.resource_skus.list(filter="location eq '%s'" % region)
|
|
||||||
)
|
|
||||||
sku_names: List[str] = []
|
|
||||||
for sku in skus:
|
|
||||||
available = True
|
|
||||||
if sku.restrictions is not None:
|
|
||||||
for restriction in sku.restrictions:
|
|
||||||
if restriction.type == ResourceSkuRestrictionsType.location and (
|
|
||||||
region.upper() in [v.upper() for v in restriction.values]
|
|
||||||
):
|
|
||||||
available = False
|
|
||||||
break
|
|
||||||
|
|
||||||
if available:
|
|
||||||
sku_names.append(sku.name)
|
|
||||||
return sku_names
|
|
@ -1,34 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
from typing import Optional, Tuple
|
|
||||||
|
|
||||||
from onefuzztypes.events import EventInstanceConfigUpdated
|
|
||||||
from onefuzztypes.models import InstanceConfig as BASE_CONFIG
|
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
from .azure.creds import get_instance_name
|
|
||||||
from .events import send_event
|
|
||||||
from .orm import ORMMixin
|
|
||||||
|
|
||||||
|
|
||||||
class InstanceConfig(BASE_CONFIG, ORMMixin):
|
|
||||||
instance_name: str = Field(default_factory=get_instance_name)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def key_fields(cls) -> Tuple[str, Optional[str]]:
|
|
||||||
return ("instance_name", None)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def fetch(cls) -> "InstanceConfig":
|
|
||||||
entry = cls.get(get_instance_name())
|
|
||||||
if entry is None:
|
|
||||||
entry = cls(allowed_aad_tenants=[])
|
|
||||||
entry.save()
|
|
||||||
return entry
|
|
||||||
|
|
||||||
def save(self, new: bool = False, require_etag: bool = False) -> None:
|
|
||||||
super().save(new=new, require_etag=require_etag)
|
|
||||||
send_event(EventInstanceConfigUpdated(config=self))
|
|
@ -1,204 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import urllib
|
|
||||||
from typing import Callable, Optional
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
import azure.functions as func
|
|
||||||
from azure.functions import HttpRequest
|
|
||||||
from memoization import cached
|
|
||||||
from onefuzztypes.enums import ErrorCode
|
|
||||||
from onefuzztypes.models import Error, UserInfo
|
|
||||||
|
|
||||||
from .azure.creds import get_scaleset_principal_id
|
|
||||||
from .azure.group_membership import create_group_membership_checker
|
|
||||||
from .config import InstanceConfig
|
|
||||||
from .request import not_ok
|
|
||||||
from .request_access import RequestAccess
|
|
||||||
from .user_credentials import parse_jwt_token
|
|
||||||
from .workers.pools import Pool
|
|
||||||
from .workers.scalesets import Scaleset
|
|
||||||
|
|
||||||
|
|
||||||
@cached(ttl=60)
|
|
||||||
def get_rules() -> Optional[RequestAccess]:
|
|
||||||
config = InstanceConfig.fetch()
|
|
||||||
if config.api_access_rules:
|
|
||||||
return RequestAccess.build(config.api_access_rules)
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def check_access(req: HttpRequest) -> Optional[Error]:
|
|
||||||
rules = get_rules()
|
|
||||||
|
|
||||||
# Nothing to enforce if there are no rules.
|
|
||||||
if not rules:
|
|
||||||
return None
|
|
||||||
|
|
||||||
path = urllib.parse.urlparse(req.url).path
|
|
||||||
rule = rules.get_matching_rules(req.method, path)
|
|
||||||
|
|
||||||
# No restriction defined on this endpoint.
|
|
||||||
if not rule:
|
|
||||||
return None
|
|
||||||
|
|
||||||
member_id = UUID(req.headers["x-ms-client-principal-id"])
|
|
||||||
|
|
||||||
try:
|
|
||||||
membership_checker = create_group_membership_checker()
|
|
||||||
allowed = membership_checker.is_member(rule.allowed_groups_ids, member_id)
|
|
||||||
if not allowed:
|
|
||||||
logging.error(
|
|
||||||
"unauthorized access: %s is not authorized to access in %s",
|
|
||||||
member_id,
|
|
||||||
req.url,
|
|
||||||
)
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.UNAUTHORIZED,
|
|
||||||
errors=["not approved to use this endpoint"],
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.UNAUTHORIZED,
|
|
||||||
errors=["unable to interact with graph", str(e)],
|
|
||||||
)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
@cached(ttl=60)
|
|
||||||
def is_agent(token_data: UserInfo) -> bool:
|
|
||||||
|
|
||||||
if token_data.object_id:
|
|
||||||
# backward compatibility case for scalesets deployed before the migration
|
|
||||||
# to user assigned managed id
|
|
||||||
scalesets = Scaleset.get_by_object_id(token_data.object_id)
|
|
||||||
if len(scalesets) > 0:
|
|
||||||
return True
|
|
||||||
|
|
||||||
# verify object_id against the user assigned managed identity
|
|
||||||
principal_id: UUID = get_scaleset_principal_id()
|
|
||||||
return principal_id == token_data.object_id
|
|
||||||
|
|
||||||
if not token_data.application_id:
|
|
||||||
return False
|
|
||||||
|
|
||||||
pools = Pool.search(query={"client_id": [token_data.application_id]})
|
|
||||||
if len(pools) > 0:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def can_modify_config_impl(config: InstanceConfig, user_info: UserInfo) -> bool:
|
|
||||||
if config.admins is None:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return user_info.object_id in config.admins
|
|
||||||
|
|
||||||
|
|
||||||
def can_modify_config(req: func.HttpRequest, config: InstanceConfig) -> bool:
|
|
||||||
user_info = parse_jwt_token(req)
|
|
||||||
if not isinstance(user_info, UserInfo):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return can_modify_config_impl(config, user_info)
|
|
||||||
|
|
||||||
|
|
||||||
def check_require_admins_impl(
|
|
||||||
config: InstanceConfig, user_info: UserInfo
|
|
||||||
) -> Optional[Error]:
|
|
||||||
if not config.require_admin_privileges:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if config.admins is None:
|
|
||||||
return Error(code=ErrorCode.UNAUTHORIZED, errors=["pool modification disabled"])
|
|
||||||
|
|
||||||
if user_info.object_id in config.admins:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return Error(code=ErrorCode.UNAUTHORIZED, errors=["not authorized to manage pools"])
|
|
||||||
|
|
||||||
|
|
||||||
def check_require_admins(req: func.HttpRequest) -> Optional[Error]:
|
|
||||||
user_info = parse_jwt_token(req)
|
|
||||||
if isinstance(user_info, Error):
|
|
||||||
return user_info
|
|
||||||
|
|
||||||
# When there are no admins in the `admins` list, all users are considered
|
|
||||||
# admins. However, `require_admin_privileges` is still useful to protect from
|
|
||||||
# mistakes.
|
|
||||||
#
|
|
||||||
# To make changes while still protecting against accidental changes to
|
|
||||||
# pools, do the following:
|
|
||||||
#
|
|
||||||
# 1. set `require_admin_privileges` to `False`
|
|
||||||
# 2. make the change
|
|
||||||
# 3. set `require_admin_privileges` to `True`
|
|
||||||
|
|
||||||
config = InstanceConfig.fetch()
|
|
||||||
|
|
||||||
return check_require_admins_impl(config, user_info)
|
|
||||||
|
|
||||||
|
|
||||||
def is_user(token_data: UserInfo) -> bool:
|
|
||||||
return not is_agent(token_data)
|
|
||||||
|
|
||||||
|
|
||||||
def reject(req: func.HttpRequest, token: UserInfo) -> func.HttpResponse:
|
|
||||||
logging.error(
|
|
||||||
"reject token. url:%s token:%s body:%s",
|
|
||||||
repr(req.url),
|
|
||||||
repr(token),
|
|
||||||
repr(req.get_body()),
|
|
||||||
)
|
|
||||||
return not_ok(
|
|
||||||
Error(code=ErrorCode.UNAUTHORIZED, errors=["Unrecognized agent"]),
|
|
||||||
status_code=401,
|
|
||||||
context="token verification",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def call_if(
|
|
||||||
req: func.HttpRequest,
|
|
||||||
method: Callable[[func.HttpRequest], func.HttpResponse],
|
|
||||||
*,
|
|
||||||
allow_user: bool = False,
|
|
||||||
allow_agent: bool = False
|
|
||||||
) -> func.HttpResponse:
|
|
||||||
|
|
||||||
token = parse_jwt_token(req)
|
|
||||||
if isinstance(token, Error):
|
|
||||||
return not_ok(token, status_code=401, context="token verification")
|
|
||||||
|
|
||||||
if is_user(token):
|
|
||||||
if not allow_user:
|
|
||||||
return reject(req, token)
|
|
||||||
|
|
||||||
access = check_access(req)
|
|
||||||
if isinstance(access, Error):
|
|
||||||
return not_ok(access, status_code=401, context="access control")
|
|
||||||
|
|
||||||
if is_agent(token) and not allow_agent:
|
|
||||||
return reject(req, token)
|
|
||||||
|
|
||||||
return method(req)
|
|
||||||
|
|
||||||
|
|
||||||
def call_if_user(
|
|
||||||
req: func.HttpRequest, method: Callable[[func.HttpRequest], func.HttpResponse]
|
|
||||||
) -> func.HttpResponse:
|
|
||||||
|
|
||||||
return call_if(req, method, allow_user=True)
|
|
||||||
|
|
||||||
|
|
||||||
def call_if_agent(
|
|
||||||
req: func.HttpRequest, method: Callable[[func.HttpRequest], func.HttpResponse]
|
|
||||||
) -> func.HttpResponse:
|
|
||||||
|
|
||||||
return call_if(req, method, allow_agent=True)
|
|
@ -1,81 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from onefuzztypes.events import Event, EventMessage, EventType, get_event_type
|
|
||||||
from onefuzztypes.models import UserInfo
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from .azure.creds import get_instance_id, get_instance_name
|
|
||||||
from .azure.queue import send_message
|
|
||||||
from .azure.storage import StorageType
|
|
||||||
from .webhooks import Webhook
|
|
||||||
|
|
||||||
|
|
||||||
class SignalREvent(BaseModel):
|
|
||||||
target: str
|
|
||||||
arguments: List[EventMessage]
|
|
||||||
|
|
||||||
|
|
||||||
def queue_signalr_event(event_message: EventMessage) -> None:
|
|
||||||
message = SignalREvent(target="events", arguments=[event_message]).json().encode()
|
|
||||||
send_message("signalr-events", message, StorageType.config)
|
|
||||||
|
|
||||||
|
|
||||||
def log_event(event: Event, event_type: EventType) -> None:
|
|
||||||
scrubbed_event = filter_event(event)
|
|
||||||
logging.info(
|
|
||||||
"sending event: %s - %s", event_type, scrubbed_event.json(exclude_none=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def filter_event(event: Event) -> BaseModel:
|
|
||||||
clone_event = event.copy(deep=True)
|
|
||||||
filtered_event = filter_event_recurse(clone_event)
|
|
||||||
return filtered_event
|
|
||||||
|
|
||||||
|
|
||||||
def filter_event_recurse(entry: BaseModel) -> BaseModel:
|
|
||||||
|
|
||||||
for field in entry.__fields__:
|
|
||||||
field_data = getattr(entry, field)
|
|
||||||
|
|
||||||
if isinstance(field_data, UserInfo):
|
|
||||||
field_data = None
|
|
||||||
elif isinstance(field_data, list):
|
|
||||||
for (i, value) in enumerate(field_data):
|
|
||||||
if isinstance(value, BaseModel):
|
|
||||||
field_data[i] = filter_event_recurse(value)
|
|
||||||
elif isinstance(field_data, dict):
|
|
||||||
for (key, value) in field_data.items():
|
|
||||||
if isinstance(value, BaseModel):
|
|
||||||
field_data[key] = filter_event_recurse(value)
|
|
||||||
elif isinstance(field_data, BaseModel):
|
|
||||||
field_data = filter_event_recurse(field_data)
|
|
||||||
|
|
||||||
setattr(entry, field, field_data)
|
|
||||||
|
|
||||||
return entry
|
|
||||||
|
|
||||||
|
|
||||||
def send_event(event: Event) -> None:
|
|
||||||
event_type = get_event_type(event)
|
|
||||||
|
|
||||||
event_message = EventMessage(
|
|
||||||
event_type=event_type,
|
|
||||||
event=event.copy(deep=True),
|
|
||||||
instance_id=get_instance_id(),
|
|
||||||
instance_name=get_instance_name(),
|
|
||||||
)
|
|
||||||
|
|
||||||
# work around odd bug with Event Message creation. See PR 939
|
|
||||||
if event_message.event != event:
|
|
||||||
event_message.event = event.copy(deep=True)
|
|
||||||
|
|
||||||
log_event(event, event_type)
|
|
||||||
queue_signalr_event(event_message)
|
|
||||||
Webhook.send_event(event_message)
|
|
@ -1,515 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import os
|
|
||||||
from typing import List, Optional
|
|
||||||
from uuid import UUID, uuid4
|
|
||||||
|
|
||||||
from onefuzztypes.enums import OS, AgentMode
|
|
||||||
from onefuzztypes.models import (
|
|
||||||
AgentConfig,
|
|
||||||
AzureMonitorExtensionConfig,
|
|
||||||
KeyvaultExtensionConfig,
|
|
||||||
Pool,
|
|
||||||
ReproConfig,
|
|
||||||
Scaleset,
|
|
||||||
)
|
|
||||||
from onefuzztypes.primitives import Container, Extension, Region
|
|
||||||
|
|
||||||
from .azure.containers import (
|
|
||||||
get_container_sas_url,
|
|
||||||
get_file_sas_url,
|
|
||||||
get_file_url,
|
|
||||||
save_blob,
|
|
||||||
)
|
|
||||||
from .azure.creds import get_agent_instance_url, get_instance_id
|
|
||||||
from .azure.log_analytics import get_monitor_settings
|
|
||||||
from .azure.queue import get_queue_sas
|
|
||||||
from .azure.storage import StorageType
|
|
||||||
from .config import InstanceConfig
|
|
||||||
from .reports import get_report
|
|
||||||
|
|
||||||
|
|
||||||
def generic_extensions(region: Region, vm_os: OS) -> List[Extension]:
|
|
||||||
instance_config = InstanceConfig.fetch()
|
|
||||||
|
|
||||||
extensions = [monitor_extension(region, vm_os)]
|
|
||||||
|
|
||||||
dependency = dependency_extension(region, vm_os)
|
|
||||||
if dependency:
|
|
||||||
extensions.append(dependency)
|
|
||||||
|
|
||||||
if instance_config.extensions:
|
|
||||||
|
|
||||||
if instance_config.extensions.keyvault:
|
|
||||||
keyvault = keyvault_extension(
|
|
||||||
region, instance_config.extensions.keyvault, vm_os
|
|
||||||
)
|
|
||||||
extensions.append(keyvault)
|
|
||||||
|
|
||||||
if instance_config.extensions.geneva and vm_os == OS.windows:
|
|
||||||
geneva = geneva_extension(region)
|
|
||||||
extensions.append(geneva)
|
|
||||||
|
|
||||||
if instance_config.extensions.azure_monitor and vm_os == OS.linux:
|
|
||||||
azmon = azmon_extension(region, instance_config.extensions.azure_monitor)
|
|
||||||
extensions.append(azmon)
|
|
||||||
|
|
||||||
if instance_config.extensions.azure_security and vm_os == OS.linux:
|
|
||||||
azsec = azsec_extension(region)
|
|
||||||
extensions.append(azsec)
|
|
||||||
|
|
||||||
return extensions
|
|
||||||
|
|
||||||
|
|
||||||
def monitor_extension(region: Region, vm_os: OS) -> Extension:
|
|
||||||
settings = get_monitor_settings()
|
|
||||||
|
|
||||||
if vm_os == OS.windows:
|
|
||||||
return {
|
|
||||||
"name": "OMSExtension",
|
|
||||||
"publisher": "Microsoft.EnterpriseCloud.Monitoring",
|
|
||||||
"type": "MicrosoftMonitoringAgent",
|
|
||||||
"typeHandlerVersion": "1.0",
|
|
||||||
"location": region,
|
|
||||||
"autoUpgradeMinorVersion": True,
|
|
||||||
"settings": {"workspaceId": settings["id"]},
|
|
||||||
"protectedSettings": {"workspaceKey": settings["key"]},
|
|
||||||
}
|
|
||||||
elif vm_os == OS.linux:
|
|
||||||
return {
|
|
||||||
"name": "OMSExtension",
|
|
||||||
"publisher": "Microsoft.EnterpriseCloud.Monitoring",
|
|
||||||
"type": "OmsAgentForLinux",
|
|
||||||
"typeHandlerVersion": "1.12",
|
|
||||||
"location": region,
|
|
||||||
"autoUpgradeMinorVersion": True,
|
|
||||||
"settings": {"workspaceId": settings["id"]},
|
|
||||||
"protectedSettings": {"workspaceKey": settings["key"]},
|
|
||||||
}
|
|
||||||
raise NotImplementedError("unsupported os: %s" % vm_os)
|
|
||||||
|
|
||||||
|
|
||||||
def dependency_extension(region: Region, vm_os: OS) -> Optional[Extension]:
|
|
||||||
if vm_os == OS.windows:
|
|
||||||
extension = {
|
|
||||||
"name": "DependencyAgentWindows",
|
|
||||||
"publisher": "Microsoft.Azure.Monitoring.DependencyAgent",
|
|
||||||
"type": "DependencyAgentWindows",
|
|
||||||
"typeHandlerVersion": "9.5",
|
|
||||||
"location": region,
|
|
||||||
"autoUpgradeMinorVersion": True,
|
|
||||||
}
|
|
||||||
return extension
|
|
||||||
else:
|
|
||||||
# TODO: dependency agent for linux is not reliable
|
|
||||||
# extension = {
|
|
||||||
# "name": "DependencyAgentLinux",
|
|
||||||
# "publisher": "Microsoft.Azure.Monitoring.DependencyAgent",
|
|
||||||
# "type": "DependencyAgentLinux",
|
|
||||||
# "typeHandlerVersion": "9.5",
|
|
||||||
# "location": vm.region,
|
|
||||||
# "autoUpgradeMinorVersion": True,
|
|
||||||
# }
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def geneva_extension(region: Region) -> Extension:
|
|
||||||
|
|
||||||
return {
|
|
||||||
"name": "Microsoft.Azure.Geneva.GenevaMonitoring",
|
|
||||||
"publisher": "Microsoft.Azure.Geneva",
|
|
||||||
"type": "GenevaMonitoring",
|
|
||||||
"typeHandlerVersion": "2.0",
|
|
||||||
"location": region,
|
|
||||||
"autoUpgradeMinorVersion": True,
|
|
||||||
"enableAutomaticUpgrade": True,
|
|
||||||
"settings": {},
|
|
||||||
"protectedSettings": {},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def azmon_extension(
|
|
||||||
region: Region, azure_monitor: AzureMonitorExtensionConfig
|
|
||||||
) -> Extension:
|
|
||||||
|
|
||||||
auth_id = azure_monitor.monitoringGCSAuthId
|
|
||||||
config_version = azure_monitor.config_version
|
|
||||||
moniker = azure_monitor.moniker
|
|
||||||
namespace = azure_monitor.namespace
|
|
||||||
environment = azure_monitor.monitoringGSEnvironment
|
|
||||||
account = azure_monitor.monitoringGCSAccount
|
|
||||||
auth_id_type = azure_monitor.monitoringGCSAuthIdType
|
|
||||||
|
|
||||||
return {
|
|
||||||
"name": "AzureMonitorLinuxAgent",
|
|
||||||
"publisher": "Microsoft.Azure.Monitor",
|
|
||||||
"location": region,
|
|
||||||
"type": "AzureMonitorLinuxAgent",
|
|
||||||
"typeHandlerVersion": "1.0",
|
|
||||||
"autoUpgradeMinorVersion": True,
|
|
||||||
"settings": {"GCS_AUTO_CONFIG": True},
|
|
||||||
"protectedsettings": {
|
|
||||||
"configVersion": config_version,
|
|
||||||
"moniker": moniker,
|
|
||||||
"namespace": namespace,
|
|
||||||
"monitoringGCSEnvironment": environment,
|
|
||||||
"monitoringGCSAccount": account,
|
|
||||||
"monitoringGCSRegion": region,
|
|
||||||
"monitoringGCSAuthId": auth_id,
|
|
||||||
"monitoringGCSAuthIdType": auth_id_type,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def azsec_extension(region: Region) -> Extension:
|
|
||||||
|
|
||||||
return {
|
|
||||||
"name": "AzureSecurityLinuxAgent",
|
|
||||||
"publisher": "Microsoft.Azure.Security.Monitoring",
|
|
||||||
"location": region,
|
|
||||||
"type": "AzureSecurityLinuxAgent",
|
|
||||||
"typeHandlerVersion": "2.0",
|
|
||||||
"autoUpgradeMinorVersion": True,
|
|
||||||
"settings": {"enableGenevaUpload": True, "enableAutoConfig": True},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def keyvault_extension(
|
|
||||||
region: Region, keyvault: KeyvaultExtensionConfig, vm_os: OS
|
|
||||||
) -> Extension:
|
|
||||||
|
|
||||||
keyvault_name = keyvault.keyvault_name
|
|
||||||
cert_name = keyvault.cert_name
|
|
||||||
uri = keyvault_name + cert_name
|
|
||||||
|
|
||||||
if vm_os == OS.windows:
|
|
||||||
return {
|
|
||||||
"name": "KVVMExtensionForWindows",
|
|
||||||
"location": region,
|
|
||||||
"publisher": "Microsoft.Azure.KeyVault",
|
|
||||||
"type": "KeyVaultForWindows",
|
|
||||||
"typeHandlerVersion": "1.0",
|
|
||||||
"autoUpgradeMinorVersion": True,
|
|
||||||
"settings": {
|
|
||||||
"secretsManagementSettings": {
|
|
||||||
"pollingIntervalInS": "3600",
|
|
||||||
"certificateStoreName": "MY",
|
|
||||||
"linkOnRenewal": False,
|
|
||||||
"certificateStoreLocation": "LocalMachine",
|
|
||||||
"requireInitialSync": True,
|
|
||||||
"observedCertificates": [uri],
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
elif vm_os == OS.linux:
|
|
||||||
cert_path = keyvault.cert_path
|
|
||||||
extension_store = keyvault.extension_store
|
|
||||||
cert_location = cert_path + extension_store
|
|
||||||
return {
|
|
||||||
"name": "KVVMExtensionForLinux",
|
|
||||||
"location": region,
|
|
||||||
"publisher": "Microsoft.Azure.KeyVault",
|
|
||||||
"type": "KeyVaultForLinux",
|
|
||||||
"typeHandlerVersion": "2.0",
|
|
||||||
"autoUpgradeMinorVersion": True,
|
|
||||||
"settings": {
|
|
||||||
"secretsManagementSettings": {
|
|
||||||
"pollingIntervalInS": "3600",
|
|
||||||
"certificateStoreLocation": cert_location,
|
|
||||||
"observedCertificates": [uri],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
raise NotImplementedError("unsupported os: %s" % vm_os)
|
|
||||||
|
|
||||||
|
|
||||||
def build_scaleset_script(pool: Pool, scaleset: Scaleset) -> str:
|
|
||||||
commands = []
|
|
||||||
extension = "ps1" if pool.os == OS.windows else "sh"
|
|
||||||
filename = f"{scaleset.scaleset_id}/scaleset-setup.{extension}"
|
|
||||||
sep = "\r\n" if pool.os == OS.windows else "\n"
|
|
||||||
|
|
||||||
if pool.os == OS.windows and scaleset.auth is not None:
|
|
||||||
ssh_key = scaleset.auth.public_key.strip()
|
|
||||||
ssh_path = "$env:ProgramData/ssh/administrators_authorized_keys"
|
|
||||||
commands += [f'Set-Content -Path {ssh_path} -Value "{ssh_key}"']
|
|
||||||
|
|
||||||
save_blob(
|
|
||||||
Container("vm-scripts"), filename, sep.join(commands) + sep, StorageType.config
|
|
||||||
)
|
|
||||||
return get_file_url(Container("vm-scripts"), filename, StorageType.config)
|
|
||||||
|
|
||||||
|
|
||||||
def build_pool_config(pool: Pool) -> str:
|
|
||||||
config = AgentConfig(
|
|
||||||
pool_name=pool.name,
|
|
||||||
onefuzz_url=get_agent_instance_url(),
|
|
||||||
heartbeat_queue=get_queue_sas(
|
|
||||||
"node-heartbeat",
|
|
||||||
StorageType.config,
|
|
||||||
add=True,
|
|
||||||
),
|
|
||||||
instance_telemetry_key=os.environ.get("APPINSIGHTS_INSTRUMENTATIONKEY"),
|
|
||||||
microsoft_telemetry_key=os.environ.get("ONEFUZZ_TELEMETRY"),
|
|
||||||
instance_id=get_instance_id(),
|
|
||||||
)
|
|
||||||
|
|
||||||
multi_tenant_domain = os.environ.get("MULTI_TENANT_DOMAIN")
|
|
||||||
if multi_tenant_domain:
|
|
||||||
config.multi_tenant_domain = multi_tenant_domain
|
|
||||||
|
|
||||||
filename = f"{pool.name}/config.json"
|
|
||||||
|
|
||||||
save_blob(
|
|
||||||
Container("vm-scripts"),
|
|
||||||
filename,
|
|
||||||
config.json(),
|
|
||||||
StorageType.config,
|
|
||||||
)
|
|
||||||
|
|
||||||
return config_url(Container("vm-scripts"), filename, False)
|
|
||||||
|
|
||||||
|
|
||||||
def update_managed_scripts() -> None:
|
|
||||||
commands = [
|
|
||||||
"azcopy sync '%s' instance-specific-setup"
|
|
||||||
% (
|
|
||||||
get_container_sas_url(
|
|
||||||
Container("instance-specific-setup"),
|
|
||||||
StorageType.config,
|
|
||||||
read=True,
|
|
||||||
list_=True,
|
|
||||||
)
|
|
||||||
),
|
|
||||||
"azcopy sync '%s' tools"
|
|
||||||
% (
|
|
||||||
get_container_sas_url(
|
|
||||||
Container("tools"), StorageType.config, read=True, list_=True
|
|
||||||
)
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
save_blob(
|
|
||||||
Container("vm-scripts"),
|
|
||||||
"managed.ps1",
|
|
||||||
"\r\n".join(commands) + "\r\n",
|
|
||||||
StorageType.config,
|
|
||||||
)
|
|
||||||
save_blob(
|
|
||||||
Container("vm-scripts"),
|
|
||||||
"managed.sh",
|
|
||||||
"\n".join(commands) + "\n",
|
|
||||||
StorageType.config,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def config_url(container: Container, filename: str, with_sas: bool) -> str:
|
|
||||||
if with_sas:
|
|
||||||
return get_file_sas_url(container, filename, StorageType.config, read=True)
|
|
||||||
else:
|
|
||||||
return get_file_url(container, filename, StorageType.config)
|
|
||||||
|
|
||||||
|
|
||||||
def agent_config(
|
|
||||||
region: Region,
|
|
||||||
vm_os: OS,
|
|
||||||
mode: AgentMode,
|
|
||||||
*,
|
|
||||||
urls: Optional[List[str]] = None,
|
|
||||||
with_sas: bool = False,
|
|
||||||
) -> Extension:
|
|
||||||
update_managed_scripts()
|
|
||||||
|
|
||||||
if urls is None:
|
|
||||||
urls = []
|
|
||||||
|
|
||||||
if vm_os == OS.windows:
|
|
||||||
urls += [
|
|
||||||
config_url(Container("vm-scripts"), "managed.ps1", with_sas),
|
|
||||||
config_url(Container("tools"), "win64/azcopy.exe", with_sas),
|
|
||||||
config_url(
|
|
||||||
Container("tools"),
|
|
||||||
"win64/setup.ps1",
|
|
||||||
with_sas,
|
|
||||||
),
|
|
||||||
config_url(
|
|
||||||
Container("tools"),
|
|
||||||
"win64/onefuzz.ps1",
|
|
||||||
with_sas,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
to_execute_cmd = (
|
|
||||||
"powershell -ExecutionPolicy Unrestricted -File win64/setup.ps1 "
|
|
||||||
"-mode %s" % (mode.name)
|
|
||||||
)
|
|
||||||
extension = {
|
|
||||||
"name": "CustomScriptExtension",
|
|
||||||
"type": "CustomScriptExtension",
|
|
||||||
"publisher": "Microsoft.Compute",
|
|
||||||
"location": region,
|
|
||||||
"force_update_tag": uuid4(),
|
|
||||||
"type_handler_version": "1.9",
|
|
||||||
"auto_upgrade_minor_version": True,
|
|
||||||
"settings": {
|
|
||||||
"commandToExecute": to_execute_cmd,
|
|
||||||
"fileUris": urls,
|
|
||||||
},
|
|
||||||
"protectedSettings": {
|
|
||||||
"managedIdentity": {},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return extension
|
|
||||||
elif vm_os == OS.linux:
|
|
||||||
urls += [
|
|
||||||
config_url(
|
|
||||||
Container("vm-scripts"),
|
|
||||||
"managed.sh",
|
|
||||||
with_sas,
|
|
||||||
),
|
|
||||||
config_url(
|
|
||||||
Container("tools"),
|
|
||||||
"linux/azcopy",
|
|
||||||
with_sas,
|
|
||||||
),
|
|
||||||
config_url(
|
|
||||||
Container("tools"),
|
|
||||||
"linux/setup.sh",
|
|
||||||
with_sas,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
to_execute_cmd = "bash setup.sh %s" % (mode.name)
|
|
||||||
|
|
||||||
extension = {
|
|
||||||
"name": "CustomScript",
|
|
||||||
"publisher": "Microsoft.Azure.Extensions",
|
|
||||||
"type": "CustomScript",
|
|
||||||
"typeHandlerVersion": "2.1",
|
|
||||||
"location": region,
|
|
||||||
"force_update_tag": uuid4(),
|
|
||||||
"autoUpgradeMinorVersion": True,
|
|
||||||
"settings": {
|
|
||||||
"commandToExecute": to_execute_cmd,
|
|
||||||
"fileUris": urls,
|
|
||||||
},
|
|
||||||
"protectedSettings": {
|
|
||||||
"managedIdentity": {},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return extension
|
|
||||||
|
|
||||||
raise NotImplementedError("unsupported OS: %s" % vm_os)
|
|
||||||
|
|
||||||
|
|
||||||
def fuzz_extensions(pool: Pool, scaleset: Scaleset) -> List[Extension]:
|
|
||||||
urls = [build_pool_config(pool), build_scaleset_script(pool, scaleset)]
|
|
||||||
fuzz_extension = agent_config(scaleset.region, pool.os, AgentMode.fuzz, urls=urls)
|
|
||||||
extensions = generic_extensions(scaleset.region, pool.os)
|
|
||||||
extensions += [fuzz_extension]
|
|
||||||
return extensions
|
|
||||||
|
|
||||||
|
|
||||||
def repro_extensions(
|
|
||||||
region: Region,
|
|
||||||
repro_os: OS,
|
|
||||||
repro_id: UUID,
|
|
||||||
repro_config: ReproConfig,
|
|
||||||
setup_container: Optional[Container],
|
|
||||||
) -> List[Extension]:
|
|
||||||
# TODO - what about contents of repro.ps1 / repro.sh?
|
|
||||||
report = get_report(repro_config.container, repro_config.path)
|
|
||||||
if report is None:
|
|
||||||
raise Exception("invalid report: %s" % repro_config)
|
|
||||||
|
|
||||||
if report.input_blob is None:
|
|
||||||
raise Exception("unable to perform reproduction without an input blob")
|
|
||||||
|
|
||||||
commands = []
|
|
||||||
if setup_container:
|
|
||||||
commands += [
|
|
||||||
"azcopy sync '%s' ./setup"
|
|
||||||
% (
|
|
||||||
get_container_sas_url(
|
|
||||||
setup_container, StorageType.corpus, read=True, list_=True
|
|
||||||
)
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
urls = [
|
|
||||||
get_file_sas_url(
|
|
||||||
repro_config.container, repro_config.path, StorageType.corpus, read=True
|
|
||||||
),
|
|
||||||
get_file_sas_url(
|
|
||||||
report.input_blob.container,
|
|
||||||
report.input_blob.name,
|
|
||||||
StorageType.corpus,
|
|
||||||
read=True,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
repro_files = []
|
|
||||||
if repro_os == OS.windows:
|
|
||||||
repro_files = ["%s/repro.ps1" % repro_id]
|
|
||||||
task_script = "\r\n".join(commands)
|
|
||||||
script_name = "task-setup.ps1"
|
|
||||||
else:
|
|
||||||
repro_files = ["%s/repro.sh" % repro_id, "%s/repro-stdout.sh" % repro_id]
|
|
||||||
commands += ["chmod -R +x setup"]
|
|
||||||
task_script = "\n".join(commands)
|
|
||||||
script_name = "task-setup.sh"
|
|
||||||
|
|
||||||
save_blob(
|
|
||||||
Container("task-configs"),
|
|
||||||
"%s/%s" % (repro_id, script_name),
|
|
||||||
task_script,
|
|
||||||
StorageType.config,
|
|
||||||
)
|
|
||||||
|
|
||||||
for repro_file in repro_files:
|
|
||||||
urls += [
|
|
||||||
get_file_sas_url(
|
|
||||||
Container("repro-scripts"),
|
|
||||||
repro_file,
|
|
||||||
StorageType.config,
|
|
||||||
read=True,
|
|
||||||
),
|
|
||||||
get_file_sas_url(
|
|
||||||
Container("task-configs"),
|
|
||||||
"%s/%s" % (repro_id, script_name),
|
|
||||||
StorageType.config,
|
|
||||||
read=True,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
base_extension = agent_config(
|
|
||||||
region, repro_os, AgentMode.repro, urls=urls, with_sas=True
|
|
||||||
)
|
|
||||||
extensions = generic_extensions(region, repro_os)
|
|
||||||
extensions += [base_extension]
|
|
||||||
return extensions
|
|
||||||
|
|
||||||
|
|
||||||
def proxy_manager_extensions(region: Region, proxy_id: UUID) -> List[Extension]:
|
|
||||||
urls = [
|
|
||||||
get_file_sas_url(
|
|
||||||
Container("proxy-configs"),
|
|
||||||
"%s/%s/config.json" % (region, proxy_id),
|
|
||||||
StorageType.config,
|
|
||||||
read=True,
|
|
||||||
),
|
|
||||||
get_file_sas_url(
|
|
||||||
Container("tools"),
|
|
||||||
"linux/onefuzz-proxy-manager",
|
|
||||||
StorageType.config,
|
|
||||||
read=True,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
base_extension = agent_config(
|
|
||||||
region, OS.linux, AgentMode.proxy, urls=urls, with_sas=True
|
|
||||||
)
|
|
||||||
extensions = generic_extensions(region, OS.linux)
|
|
||||||
extensions += [base_extension]
|
|
||||||
return extensions
|
|
@ -1,14 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
from .afl import afl_linux, afl_windows
|
|
||||||
from .libfuzzer import libfuzzer_linux, libfuzzer_windows
|
|
||||||
|
|
||||||
TEMPLATES = {
|
|
||||||
"afl_windows": afl_windows,
|
|
||||||
"afl_linux": afl_linux,
|
|
||||||
"libfuzzer_linux": libfuzzer_linux,
|
|
||||||
"libfuzzer_windows": libfuzzer_windows,
|
|
||||||
}
|
|
@ -1,271 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from onefuzztypes.enums import (
|
|
||||||
OS,
|
|
||||||
ContainerType,
|
|
||||||
TaskType,
|
|
||||||
UserFieldOperation,
|
|
||||||
UserFieldType,
|
|
||||||
)
|
|
||||||
from onefuzztypes.job_templates import JobTemplate, UserField, UserFieldLocation
|
|
||||||
from onefuzztypes.models import (
|
|
||||||
JobConfig,
|
|
||||||
TaskConfig,
|
|
||||||
TaskContainers,
|
|
||||||
TaskDetails,
|
|
||||||
TaskPool,
|
|
||||||
)
|
|
||||||
from onefuzztypes.primitives import Container, PoolName
|
|
||||||
|
|
||||||
from .common import (
|
|
||||||
DURATION_HELP,
|
|
||||||
POOL_HELP,
|
|
||||||
REBOOT_HELP,
|
|
||||||
RETRY_COUNT_HELP,
|
|
||||||
TAGS_HELP,
|
|
||||||
TARGET_EXE_HELP,
|
|
||||||
TARGET_OPTIONS_HELP,
|
|
||||||
VM_COUNT_HELP,
|
|
||||||
)
|
|
||||||
|
|
||||||
afl_linux = JobTemplate(
|
|
||||||
os=OS.linux,
|
|
||||||
job=JobConfig(project="", name=Container(""), build="", duration=1),
|
|
||||||
tasks=[
|
|
||||||
TaskConfig(
|
|
||||||
job_id=(UUID(int=0)),
|
|
||||||
task=TaskDetails(
|
|
||||||
type=TaskType.generic_supervisor,
|
|
||||||
duration=1,
|
|
||||||
target_exe="fuzz.exe",
|
|
||||||
target_env={},
|
|
||||||
target_options=[],
|
|
||||||
supervisor_exe="",
|
|
||||||
supervisor_options=[],
|
|
||||||
supervisor_input_marker="@@",
|
|
||||||
),
|
|
||||||
pool=TaskPool(count=1, pool_name=PoolName("")),
|
|
||||||
containers=[
|
|
||||||
TaskContainers(
|
|
||||||
name=Container("afl-container-name"), type=ContainerType.tools
|
|
||||||
),
|
|
||||||
TaskContainers(name=Container(""), type=ContainerType.setup),
|
|
||||||
TaskContainers(name=Container(""), type=ContainerType.crashes),
|
|
||||||
TaskContainers(name=Container(""), type=ContainerType.inputs),
|
|
||||||
],
|
|
||||||
tags={},
|
|
||||||
),
|
|
||||||
TaskConfig(
|
|
||||||
job_id=UUID(int=0),
|
|
||||||
prereq_tasks=[UUID(int=0)],
|
|
||||||
task=TaskDetails(
|
|
||||||
type=TaskType.generic_crash_report,
|
|
||||||
duration=1,
|
|
||||||
target_exe="fuzz.exe",
|
|
||||||
target_env={},
|
|
||||||
target_options=[],
|
|
||||||
check_debugger=True,
|
|
||||||
),
|
|
||||||
pool=TaskPool(count=1, pool_name=PoolName("")),
|
|
||||||
containers=[
|
|
||||||
TaskContainers(name=Container(""), type=ContainerType.setup),
|
|
||||||
TaskContainers(name=Container(""), type=ContainerType.crashes),
|
|
||||||
TaskContainers(name=Container(""), type=ContainerType.no_repro),
|
|
||||||
TaskContainers(name=Container(""), type=ContainerType.reports),
|
|
||||||
TaskContainers(name=Container(""), type=ContainerType.unique_reports),
|
|
||||||
],
|
|
||||||
tags={},
|
|
||||||
),
|
|
||||||
],
|
|
||||||
notifications=[],
|
|
||||||
user_fields=[
|
|
||||||
UserField(
|
|
||||||
name="pool_name",
|
|
||||||
help=POOL_HELP,
|
|
||||||
type=UserFieldType.Str,
|
|
||||||
required=True,
|
|
||||||
locations=[
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/0/pool/pool_name",
|
|
||||||
),
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/1/pool/pool_name",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
UserField(
|
|
||||||
name="duration",
|
|
||||||
help=DURATION_HELP,
|
|
||||||
type=UserFieldType.Int,
|
|
||||||
default=24,
|
|
||||||
locations=[
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/0/task/duration",
|
|
||||||
),
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/1/task/duration",
|
|
||||||
),
|
|
||||||
UserFieldLocation(op=UserFieldOperation.replace, path="/job/duration"),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
UserField(
|
|
||||||
name="target_exe",
|
|
||||||
help=TARGET_EXE_HELP,
|
|
||||||
type=UserFieldType.Str,
|
|
||||||
default="fuzz.exe",
|
|
||||||
locations=[
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/0/task/target_exe",
|
|
||||||
),
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/1/task/target_exe",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
UserField(
|
|
||||||
name="target_options",
|
|
||||||
help=TARGET_OPTIONS_HELP,
|
|
||||||
type=UserFieldType.ListStr,
|
|
||||||
locations=[
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/0/task/target_options",
|
|
||||||
),
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/1/task/target_options",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
UserField(
|
|
||||||
name="supervisor_exe",
|
|
||||||
help="Path to the AFL executable",
|
|
||||||
type=UserFieldType.Str,
|
|
||||||
default="{tools_dir}/afl-fuzz",
|
|
||||||
locations=[
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/0/task/supervisor_exe",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
UserField(
|
|
||||||
name="supervisor_options",
|
|
||||||
help="AFL command line options",
|
|
||||||
type=UserFieldType.ListStr,
|
|
||||||
default=[
|
|
||||||
"-d",
|
|
||||||
"-i",
|
|
||||||
"{input_corpus}",
|
|
||||||
"-o",
|
|
||||||
"{runtime_dir}",
|
|
||||||
"--",
|
|
||||||
"{target_exe}",
|
|
||||||
"{target_options}",
|
|
||||||
],
|
|
||||||
locations=[
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/0/task/supervisor_options",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
UserField(
|
|
||||||
name="supervisor_env",
|
|
||||||
help="Enviornment variables for AFL",
|
|
||||||
type=UserFieldType.DictStr,
|
|
||||||
locations=[
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/0/task/supervisor_env",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
UserField(
|
|
||||||
name="vm_count",
|
|
||||||
help=VM_COUNT_HELP,
|
|
||||||
type=UserFieldType.Int,
|
|
||||||
default=2,
|
|
||||||
locations=[
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/0/pool/count",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
UserField(
|
|
||||||
name="check_retry_count",
|
|
||||||
help=RETRY_COUNT_HELP,
|
|
||||||
type=UserFieldType.Int,
|
|
||||||
locations=[
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/1/task/check_retry_count",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
UserField(
|
|
||||||
name="afl_container",
|
|
||||||
help=(
|
|
||||||
"Name of the AFL storage container (use "
|
|
||||||
"this to specify alternate builds of AFL)"
|
|
||||||
),
|
|
||||||
type=UserFieldType.Str,
|
|
||||||
default="afl-linux",
|
|
||||||
locations=[
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/0/containers/0/name",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
UserField(
|
|
||||||
name="reboot_after_setup",
|
|
||||||
help=REBOOT_HELP,
|
|
||||||
type=UserFieldType.Bool,
|
|
||||||
default=False,
|
|
||||||
locations=[
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/0/task/reboot_after_setup",
|
|
||||||
),
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/1/task/reboot_after_setup",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
UserField(
|
|
||||||
name="tags",
|
|
||||||
help=TAGS_HELP,
|
|
||||||
type=UserFieldType.DictStr,
|
|
||||||
locations=[
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.add,
|
|
||||||
path="/tasks/0/tags",
|
|
||||||
),
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.add,
|
|
||||||
path="/tasks/1/tags",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
afl_windows = afl_linux.copy(deep=True)
|
|
||||||
afl_windows.os = OS.windows
|
|
||||||
for user_field in afl_windows.user_fields:
|
|
||||||
if user_field.name == "afl_container":
|
|
||||||
user_field.default = "afl-windows"
|
|
@ -1,14 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
|
|
||||||
POOL_HELP = "Execute the task on the specified pool"
|
|
||||||
DURATION_HELP = "Number of hours to execute the task"
|
|
||||||
TARGET_EXE_HELP = "Path to the target executable"
|
|
||||||
TARGET_OPTIONS_HELP = "Command line options for the target"
|
|
||||||
VM_COUNT_HELP = "Number of VMs to use for fuzzing"
|
|
||||||
RETRY_COUNT_HELP = "Number of times to retry a crash to verify reproducability"
|
|
||||||
REBOOT_HELP = "After executing the setup script, reboot the VM"
|
|
||||||
TAGS_HELP = "User provided metadata for the tasks"
|
|
@ -1,348 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from onefuzztypes.enums import (
|
|
||||||
OS,
|
|
||||||
ContainerType,
|
|
||||||
TaskType,
|
|
||||||
UserFieldOperation,
|
|
||||||
UserFieldType,
|
|
||||||
)
|
|
||||||
from onefuzztypes.job_templates import JobTemplate, UserField, UserFieldLocation
|
|
||||||
from onefuzztypes.models import (
|
|
||||||
JobConfig,
|
|
||||||
TaskConfig,
|
|
||||||
TaskContainers,
|
|
||||||
TaskDetails,
|
|
||||||
TaskPool,
|
|
||||||
)
|
|
||||||
from onefuzztypes.primitives import Container, PoolName
|
|
||||||
|
|
||||||
from .common import (
|
|
||||||
DURATION_HELP,
|
|
||||||
POOL_HELP,
|
|
||||||
REBOOT_HELP,
|
|
||||||
RETRY_COUNT_HELP,
|
|
||||||
TAGS_HELP,
|
|
||||||
TARGET_EXE_HELP,
|
|
||||||
TARGET_OPTIONS_HELP,
|
|
||||||
VM_COUNT_HELP,
|
|
||||||
)
|
|
||||||
|
|
||||||
libfuzzer_linux = JobTemplate(
|
|
||||||
os=OS.linux,
|
|
||||||
job=JobConfig(project="", name=Container(""), build="", duration=1),
|
|
||||||
tasks=[
|
|
||||||
TaskConfig(
|
|
||||||
job_id=UUID(int=0),
|
|
||||||
task=TaskDetails(
|
|
||||||
type=TaskType.libfuzzer_fuzz,
|
|
||||||
duration=1,
|
|
||||||
target_exe="fuzz.exe",
|
|
||||||
target_env={},
|
|
||||||
target_options=[],
|
|
||||||
),
|
|
||||||
pool=TaskPool(count=1, pool_name=PoolName("")),
|
|
||||||
containers=[
|
|
||||||
TaskContainers(name=Container(""), type=ContainerType.setup),
|
|
||||||
TaskContainers(name=Container(""), type=ContainerType.crashes),
|
|
||||||
TaskContainers(name=Container(""), type=ContainerType.inputs),
|
|
||||||
],
|
|
||||||
tags={},
|
|
||||||
),
|
|
||||||
TaskConfig(
|
|
||||||
job_id=UUID(int=0),
|
|
||||||
prereq_tasks=[UUID(int=0)],
|
|
||||||
task=TaskDetails(
|
|
||||||
type=TaskType.libfuzzer_crash_report,
|
|
||||||
duration=1,
|
|
||||||
target_exe="fuzz.exe",
|
|
||||||
target_env={},
|
|
||||||
target_options=[],
|
|
||||||
),
|
|
||||||
pool=TaskPool(count=1, pool_name=PoolName("")),
|
|
||||||
containers=[
|
|
||||||
TaskContainers(name=Container(""), type=ContainerType.setup),
|
|
||||||
TaskContainers(name=Container(""), type=ContainerType.crashes),
|
|
||||||
TaskContainers(name=Container(""), type=ContainerType.no_repro),
|
|
||||||
TaskContainers(name=Container(""), type=ContainerType.reports),
|
|
||||||
TaskContainers(name=Container(""), type=ContainerType.unique_reports),
|
|
||||||
],
|
|
||||||
tags={},
|
|
||||||
),
|
|
||||||
TaskConfig(
|
|
||||||
job_id=UUID(int=0),
|
|
||||||
prereq_tasks=[UUID(int=0)],
|
|
||||||
task=TaskDetails(
|
|
||||||
type=TaskType.coverage,
|
|
||||||
duration=1,
|
|
||||||
target_exe="fuzz.exe",
|
|
||||||
target_env={},
|
|
||||||
target_options=[],
|
|
||||||
),
|
|
||||||
pool=TaskPool(count=1, pool_name=PoolName("")),
|
|
||||||
containers=[
|
|
||||||
TaskContainers(name=Container(""), type=ContainerType.setup),
|
|
||||||
TaskContainers(name=Container(""), type=ContainerType.readonly_inputs),
|
|
||||||
TaskContainers(name=Container(""), type=ContainerType.coverage),
|
|
||||||
],
|
|
||||||
tags={},
|
|
||||||
),
|
|
||||||
],
|
|
||||||
notifications=[],
|
|
||||||
user_fields=[
|
|
||||||
UserField(
|
|
||||||
name="pool_name",
|
|
||||||
help=POOL_HELP,
|
|
||||||
type=UserFieldType.Str,
|
|
||||||
required=True,
|
|
||||||
locations=[
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/0/pool/pool_name",
|
|
||||||
),
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/1/pool/pool_name",
|
|
||||||
),
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/2/pool/pool_name",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
UserField(
|
|
||||||
name="target_exe",
|
|
||||||
help=TARGET_EXE_HELP,
|
|
||||||
type=UserFieldType.Str,
|
|
||||||
default="fuzz.exe",
|
|
||||||
locations=[
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/0/task/target_exe",
|
|
||||||
),
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/1/task/target_exe",
|
|
||||||
),
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/2/task/target_exe",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
UserField(
|
|
||||||
name="duration",
|
|
||||||
help=DURATION_HELP,
|
|
||||||
type=UserFieldType.Int,
|
|
||||||
default=24,
|
|
||||||
locations=[
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/0/task/duration",
|
|
||||||
),
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/1/task/duration",
|
|
||||||
),
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/2/task/duration",
|
|
||||||
),
|
|
||||||
UserFieldLocation(op=UserFieldOperation.replace, path="/job/duration"),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
UserField(
|
|
||||||
name="target_workers",
|
|
||||||
help="Number of instances of the libfuzzer target on each VM",
|
|
||||||
type=UserFieldType.Int,
|
|
||||||
locations=[
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/0/task/target_workers",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
UserField(
|
|
||||||
name="vm_count",
|
|
||||||
help=VM_COUNT_HELP,
|
|
||||||
type=UserFieldType.Int,
|
|
||||||
default=2,
|
|
||||||
locations=[
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/0/pool/count",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
UserField(
|
|
||||||
name="target_options",
|
|
||||||
help=TARGET_OPTIONS_HELP,
|
|
||||||
type=UserFieldType.ListStr,
|
|
||||||
locations=[
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/0/task/target_options",
|
|
||||||
),
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/1/task/target_options",
|
|
||||||
),
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/2/task/target_options",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
UserField(
|
|
||||||
name="target_env",
|
|
||||||
help="Environment variables for the target",
|
|
||||||
type=UserFieldType.DictStr,
|
|
||||||
locations=[
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/0/task/target_env",
|
|
||||||
),
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/1/task/target_env",
|
|
||||||
),
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/2/task/target_env",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
UserField(
|
|
||||||
name="check_fuzzer_help",
|
|
||||||
help="Verify fuzzer by checking if it supports -help=1",
|
|
||||||
type=UserFieldType.Bool,
|
|
||||||
default=True,
|
|
||||||
locations=[
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.add,
|
|
||||||
path="/tasks/0/task/check_fuzzer_help",
|
|
||||||
),
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.add,
|
|
||||||
path="/tasks/1/task/check_fuzzer_help",
|
|
||||||
),
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.add,
|
|
||||||
path="/tasks/2/task/check_fuzzer_help",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
UserField(
|
|
||||||
name="colocate",
|
|
||||||
help="Run all of the tasks on the same node",
|
|
||||||
type=UserFieldType.Bool,
|
|
||||||
default=True,
|
|
||||||
locations=[
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.add,
|
|
||||||
path="/tasks/0/colocate",
|
|
||||||
),
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.add,
|
|
||||||
path="/tasks/1/colocate",
|
|
||||||
),
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.add,
|
|
||||||
path="/tasks/2/colocate",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
UserField(
|
|
||||||
name="expect_crash_on_failure",
|
|
||||||
help="Require crashes upon non-zero exits from libfuzzer",
|
|
||||||
type=UserFieldType.Bool,
|
|
||||||
default=False,
|
|
||||||
locations=[
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.add,
|
|
||||||
path="/tasks/0/task/expect_crash_on_failure",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
UserField(
|
|
||||||
name="reboot_after_setup",
|
|
||||||
help=REBOOT_HELP,
|
|
||||||
type=UserFieldType.Bool,
|
|
||||||
default=False,
|
|
||||||
locations=[
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/0/task/reboot_after_setup",
|
|
||||||
),
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/1/task/reboot_after_setup",
|
|
||||||
),
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/2/task/reboot_after_setup",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
UserField(
|
|
||||||
name="check_retry_count",
|
|
||||||
help=RETRY_COUNT_HELP,
|
|
||||||
type=UserFieldType.Int,
|
|
||||||
locations=[
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/1/task/check_retry_count",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
UserField(
|
|
||||||
name="target_timeout",
|
|
||||||
help="Number of seconds to timeout during reproduction",
|
|
||||||
type=UserFieldType.Int,
|
|
||||||
locations=[
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/1/task/target_timeout",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
UserField(
|
|
||||||
name="minimized_stack_depth",
|
|
||||||
help="Number of frames to include in the minimized stack",
|
|
||||||
type=UserFieldType.Int,
|
|
||||||
locations=[
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.replace,
|
|
||||||
path="/tasks/1/task/minimized_stack_depth",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
UserField(
|
|
||||||
name="tags",
|
|
||||||
help=TAGS_HELP,
|
|
||||||
type=UserFieldType.DictStr,
|
|
||||||
locations=[
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.add,
|
|
||||||
path="/tasks/0/tags",
|
|
||||||
),
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.add,
|
|
||||||
path="/tasks/1/tags",
|
|
||||||
),
|
|
||||||
UserFieldLocation(
|
|
||||||
op=UserFieldOperation.add,
|
|
||||||
path="/tasks/2/tags",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
libfuzzer_windows = libfuzzer_linux.copy(deep=True)
|
|
||||||
libfuzzer_windows.os = OS.windows
|
|
@ -1,131 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
from jsonpatch import apply_patch
|
|
||||||
from memoization import cached
|
|
||||||
from onefuzztypes.enums import ContainerType, ErrorCode, UserFieldType
|
|
||||||
from onefuzztypes.job_templates import (
|
|
||||||
TEMPLATE_BASE_FIELDS,
|
|
||||||
JobTemplate,
|
|
||||||
JobTemplateConfig,
|
|
||||||
JobTemplateField,
|
|
||||||
JobTemplateRequest,
|
|
||||||
TemplateUserData,
|
|
||||||
UserField,
|
|
||||||
)
|
|
||||||
from onefuzztypes.models import Error, Result
|
|
||||||
|
|
||||||
|
|
||||||
def template_container_types(template: JobTemplate) -> List[ContainerType]:
|
|
||||||
return list(set(c.type for t in template.tasks for c in t.containers if not c.name))
|
|
||||||
|
|
||||||
|
|
||||||
@cached
|
|
||||||
def build_input_config(name: str, template: JobTemplate) -> JobTemplateConfig:
|
|
||||||
user_fields = [
|
|
||||||
JobTemplateField(
|
|
||||||
name=x.name,
|
|
||||||
type=x.type,
|
|
||||||
required=x.required,
|
|
||||||
default=x.default,
|
|
||||||
help=x.help,
|
|
||||||
)
|
|
||||||
for x in TEMPLATE_BASE_FIELDS + template.user_fields
|
|
||||||
]
|
|
||||||
containers = template_container_types(template)
|
|
||||||
|
|
||||||
return JobTemplateConfig(
|
|
||||||
os=template.os,
|
|
||||||
name=name,
|
|
||||||
user_fields=user_fields,
|
|
||||||
containers=containers,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_patches(
|
|
||||||
data: TemplateUserData, field: UserField
|
|
||||||
) -> List[Dict[str, TemplateUserData]]:
|
|
||||||
patches = []
|
|
||||||
|
|
||||||
if field.type == UserFieldType.Bool and not isinstance(data, bool):
|
|
||||||
raise Exception("invalid bool field")
|
|
||||||
if field.type == UserFieldType.Int and not isinstance(data, int):
|
|
||||||
raise Exception("invalid int field")
|
|
||||||
if field.type == UserFieldType.Str and not isinstance(data, str):
|
|
||||||
raise Exception("invalid str field")
|
|
||||||
if field.type == UserFieldType.DictStr and not isinstance(data, dict):
|
|
||||||
raise Exception("invalid DictStr field")
|
|
||||||
if field.type == UserFieldType.ListStr and not isinstance(data, list):
|
|
||||||
raise Exception("invalid ListStr field")
|
|
||||||
|
|
||||||
for location in field.locations:
|
|
||||||
patches.append(
|
|
||||||
{
|
|
||||||
"op": location.op.name,
|
|
||||||
"path": location.path,
|
|
||||||
"value": data,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return patches
|
|
||||||
|
|
||||||
|
|
||||||
def _fail(why: str) -> Error:
|
|
||||||
return Error(code=ErrorCode.INVALID_REQUEST, errors=[why])
|
|
||||||
|
|
||||||
|
|
||||||
def render(request: JobTemplateRequest, template: JobTemplate) -> Result[JobTemplate]:
|
|
||||||
patches = []
|
|
||||||
seen = set()
|
|
||||||
|
|
||||||
for name in request.user_fields:
|
|
||||||
for field in TEMPLATE_BASE_FIELDS + template.user_fields:
|
|
||||||
if field.name == name:
|
|
||||||
if name in seen:
|
|
||||||
return _fail(f"duplicate specification: {name}")
|
|
||||||
|
|
||||||
seen.add(name)
|
|
||||||
|
|
||||||
if name not in seen:
|
|
||||||
return _fail(f"extra field: {name}")
|
|
||||||
|
|
||||||
for field in TEMPLATE_BASE_FIELDS + template.user_fields:
|
|
||||||
if field.name not in request.user_fields:
|
|
||||||
if field.required:
|
|
||||||
return _fail(f"missing required field: {field.name}")
|
|
||||||
else:
|
|
||||||
# optional fields can be missing
|
|
||||||
continue
|
|
||||||
|
|
||||||
patches += build_patches(request.user_fields[field.name], field)
|
|
||||||
|
|
||||||
raw = json.loads(template.json())
|
|
||||||
updated = apply_patch(raw, patches)
|
|
||||||
rendered = JobTemplate.parse_obj(updated)
|
|
||||||
|
|
||||||
used_containers = []
|
|
||||||
for task in rendered.tasks:
|
|
||||||
for task_container in task.containers:
|
|
||||||
if task_container.name:
|
|
||||||
# only need to fill out containers with names
|
|
||||||
continue
|
|
||||||
|
|
||||||
for entry in request.containers:
|
|
||||||
if entry.type != task_container.type:
|
|
||||||
continue
|
|
||||||
task_container.name = entry.name
|
|
||||||
used_containers.append(entry)
|
|
||||||
|
|
||||||
if not task_container.name:
|
|
||||||
return _fail(f"missing container definition {task_container.type}")
|
|
||||||
|
|
||||||
for entry in request.containers:
|
|
||||||
if entry not in used_containers:
|
|
||||||
return _fail(f"unused container in request: {entry}")
|
|
||||||
|
|
||||||
return rendered
|
|
@ -1,118 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
|
|
||||||
from onefuzztypes.enums import ErrorCode
|
|
||||||
from onefuzztypes.job_templates import JobTemplateConfig
|
|
||||||
from onefuzztypes.job_templates import JobTemplateIndex as BASE_INDEX
|
|
||||||
from onefuzztypes.job_templates import JobTemplateRequest
|
|
||||||
from onefuzztypes.models import Error, Result, UserInfo
|
|
||||||
|
|
||||||
from ..jobs import Job
|
|
||||||
from ..notifications.main import Notification
|
|
||||||
from ..orm import ORMMixin
|
|
||||||
from ..tasks.config import TaskConfigError, check_config
|
|
||||||
from ..tasks.main import Task
|
|
||||||
from .defaults import TEMPLATES
|
|
||||||
from .render import build_input_config, render
|
|
||||||
|
|
||||||
|
|
||||||
class JobTemplateIndex(BASE_INDEX, ORMMixin):
|
|
||||||
@classmethod
|
|
||||||
def key_fields(cls) -> Tuple[str, Optional[str]]:
|
|
||||||
return ("name", None)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_base_entry(cls, name: str) -> Optional[BASE_INDEX]:
|
|
||||||
result = cls.get(name)
|
|
||||||
if result is not None:
|
|
||||||
return BASE_INDEX(name=name, template=result.template)
|
|
||||||
|
|
||||||
template = TEMPLATES.get(name)
|
|
||||||
if template is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return BASE_INDEX(name=name, template=template)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_index(cls) -> List[BASE_INDEX]:
|
|
||||||
entries = [BASE_INDEX(name=x.name, template=x.template) for x in cls.search()]
|
|
||||||
|
|
||||||
# if the local install has replaced the built-in templates, skip over them
|
|
||||||
for name, template in TEMPLATES.items():
|
|
||||||
if any(x.name == name for x in entries):
|
|
||||||
continue
|
|
||||||
entries.append(BASE_INDEX(name=name, template=template))
|
|
||||||
|
|
||||||
return entries
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_configs(cls) -> List[JobTemplateConfig]:
|
|
||||||
configs = [build_input_config(x.name, x.template) for x in cls.get_index()]
|
|
||||||
|
|
||||||
return configs
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def execute(cls, request: JobTemplateRequest, user_info: UserInfo) -> Result[Job]:
|
|
||||||
index = cls.get(request.name)
|
|
||||||
if index is None:
|
|
||||||
if request.name not in TEMPLATES:
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.INVALID_REQUEST,
|
|
||||||
errors=["no such template: %s" % request.name],
|
|
||||||
)
|
|
||||||
base_template = TEMPLATES[request.name]
|
|
||||||
else:
|
|
||||||
base_template = index.template
|
|
||||||
|
|
||||||
template = render(request, base_template)
|
|
||||||
if isinstance(template, Error):
|
|
||||||
return template
|
|
||||||
|
|
||||||
try:
|
|
||||||
for task_config in template.tasks:
|
|
||||||
check_config(task_config)
|
|
||||||
if task_config.pool is None:
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.INVALID_REQUEST, errors=["pool not defined"]
|
|
||||||
)
|
|
||||||
|
|
||||||
except TaskConfigError as err:
|
|
||||||
return Error(code=ErrorCode.INVALID_REQUEST, errors=[str(err)])
|
|
||||||
|
|
||||||
for notification_config in template.notifications:
|
|
||||||
for task_container in request.containers:
|
|
||||||
if task_container.type == notification_config.container_type:
|
|
||||||
notification = Notification.create(
|
|
||||||
task_container.name,
|
|
||||||
notification_config.notification.config,
|
|
||||||
True,
|
|
||||||
)
|
|
||||||
if isinstance(notification, Error):
|
|
||||||
return notification
|
|
||||||
|
|
||||||
job = Job(config=template.job)
|
|
||||||
job.save()
|
|
||||||
|
|
||||||
tasks: List[Task] = []
|
|
||||||
for task_config in template.tasks:
|
|
||||||
task_config.job_id = job.job_id
|
|
||||||
if task_config.prereq_tasks:
|
|
||||||
# pydantic verifies prereq_tasks in u128 form are index refs to
|
|
||||||
# previously generated tasks
|
|
||||||
task_config.prereq_tasks = [
|
|
||||||
tasks[x.int].task_id for x in task_config.prereq_tasks
|
|
||||||
]
|
|
||||||
|
|
||||||
task = Task.create(
|
|
||||||
config=task_config, job_id=job.job_id, user_info=user_info
|
|
||||||
)
|
|
||||||
if isinstance(task, Error):
|
|
||||||
return task
|
|
||||||
|
|
||||||
tasks.append(task)
|
|
||||||
|
|
||||||
return job
|
|
@ -1,144 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
|
|
||||||
from onefuzztypes.enums import ErrorCode, JobState, TaskState
|
|
||||||
from onefuzztypes.events import EventJobCreated, EventJobStopped, JobTaskStopped
|
|
||||||
from onefuzztypes.models import Error
|
|
||||||
from onefuzztypes.models import Job as BASE_JOB
|
|
||||||
|
|
||||||
from .events import send_event
|
|
||||||
from .orm import MappingIntStrAny, ORMMixin, QueryFilter
|
|
||||||
from .tasks.main import Task
|
|
||||||
|
|
||||||
JOB_LOG_PREFIX = "jobs: "
|
|
||||||
JOB_NEVER_STARTED_DURATION: timedelta = timedelta(days=30)
|
|
||||||
|
|
||||||
|
|
||||||
class Job(BASE_JOB, ORMMixin):
|
|
||||||
@classmethod
|
|
||||||
def key_fields(cls) -> Tuple[str, Optional[str]]:
|
|
||||||
return ("job_id", None)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def search_states(cls, *, states: Optional[List[JobState]] = None) -> List["Job"]:
|
|
||||||
query: QueryFilter = {}
|
|
||||||
if states:
|
|
||||||
query["state"] = states
|
|
||||||
return cls.search(query=query)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def search_expired(cls) -> List["Job"]:
|
|
||||||
time_filter = "end_time lt datetime'%s'" % datetime.utcnow().isoformat()
|
|
||||||
|
|
||||||
return cls.search(
|
|
||||||
query={"state": JobState.available()}, raw_unchecked_filter=time_filter
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def stop_never_started_jobs(cls) -> None:
|
|
||||||
# Note, the "not(end_time...)" with end_time set long before the use of
|
|
||||||
# OneFuzz enables identifying those without end_time being set.
|
|
||||||
last_timestamp = (datetime.utcnow() - JOB_NEVER_STARTED_DURATION).isoformat()
|
|
||||||
|
|
||||||
time_filter = (
|
|
||||||
f"Timestamp lt datetime'{last_timestamp}' and "
|
|
||||||
"not(end_time ge datetime'2000-01-11T00:00:00.0Z')"
|
|
||||||
)
|
|
||||||
|
|
||||||
for job in cls.search(
|
|
||||||
query={
|
|
||||||
"state": [JobState.enabled],
|
|
||||||
},
|
|
||||||
raw_unchecked_filter=time_filter,
|
|
||||||
):
|
|
||||||
for task in Task.search(query={"job_id": [job.job_id]}):
|
|
||||||
task.mark_failed(
|
|
||||||
Error(
|
|
||||||
code=ErrorCode.TASK_FAILED,
|
|
||||||
errors=["job never not start"],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.info(
|
|
||||||
JOB_LOG_PREFIX + "stopping job that never started: %s", job.job_id
|
|
||||||
)
|
|
||||||
job.stopping()
|
|
||||||
|
|
||||||
def save_exclude(self) -> Optional[MappingIntStrAny]:
|
|
||||||
return {"task_info": ...}
|
|
||||||
|
|
||||||
def telemetry_include(self) -> Optional[MappingIntStrAny]:
|
|
||||||
return {
|
|
||||||
"machine_id": ...,
|
|
||||||
"state": ...,
|
|
||||||
"scaleset_id": ...,
|
|
||||||
}
|
|
||||||
|
|
||||||
def init(self) -> None:
|
|
||||||
logging.info(JOB_LOG_PREFIX + "init: %s", self.job_id)
|
|
||||||
self.state = JobState.enabled
|
|
||||||
self.save()
|
|
||||||
|
|
||||||
def stop_if_all_done(self) -> None:
|
|
||||||
not_stopped = [
|
|
||||||
task
|
|
||||||
for task in Task.search(query={"job_id": [self.job_id]})
|
|
||||||
if task.state != TaskState.stopped
|
|
||||||
]
|
|
||||||
if not_stopped:
|
|
||||||
return
|
|
||||||
|
|
||||||
logging.info(
|
|
||||||
JOB_LOG_PREFIX + "stopping job as all tasks are stopped: %s", self.job_id
|
|
||||||
)
|
|
||||||
self.stopping()
|
|
||||||
|
|
||||||
def stopping(self) -> None:
|
|
||||||
self.state = JobState.stopping
|
|
||||||
logging.info(JOB_LOG_PREFIX + "stopping: %s", self.job_id)
|
|
||||||
tasks = Task.search(query={"job_id": [self.job_id]})
|
|
||||||
not_stopped = [task for task in tasks if task.state != TaskState.stopped]
|
|
||||||
|
|
||||||
if not_stopped:
|
|
||||||
for task in not_stopped:
|
|
||||||
task.mark_stopping()
|
|
||||||
else:
|
|
||||||
self.state = JobState.stopped
|
|
||||||
task_info = [
|
|
||||||
JobTaskStopped(
|
|
||||||
task_id=x.task_id, error=x.error, task_type=x.config.task.type
|
|
||||||
)
|
|
||||||
for x in tasks
|
|
||||||
]
|
|
||||||
send_event(
|
|
||||||
EventJobStopped(
|
|
||||||
job_id=self.job_id,
|
|
||||||
config=self.config,
|
|
||||||
user_info=self.user_info,
|
|
||||||
task_info=task_info,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.save()
|
|
||||||
|
|
||||||
def on_start(self) -> None:
|
|
||||||
# try to keep this effectively idempotent
|
|
||||||
if self.end_time is None:
|
|
||||||
self.end_time = datetime.utcnow() + timedelta(hours=self.config.duration)
|
|
||||||
self.save()
|
|
||||||
|
|
||||||
def save(self, new: bool = False, require_etag: bool = False) -> None:
|
|
||||||
created = self.etag is None
|
|
||||||
super().save(new=new, require_etag=require_etag)
|
|
||||||
|
|
||||||
if created:
|
|
||||||
send_event(
|
|
||||||
EventJobCreated(
|
|
||||||
job_id=self.job_id, config=self.config, user_info=self.user_info
|
|
||||||
)
|
|
||||||
)
|
|
@ -1,294 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Iterator, List, Optional, Tuple, Union
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from azure.devops.connection import Connection
|
|
||||||
from azure.devops.credentials import BasicAuthentication
|
|
||||||
from azure.devops.exceptions import (
|
|
||||||
AzureDevOpsAuthenticationError,
|
|
||||||
AzureDevOpsClientError,
|
|
||||||
AzureDevOpsClientRequestError,
|
|
||||||
AzureDevOpsServiceError,
|
|
||||||
)
|
|
||||||
from azure.devops.v6_0.work_item_tracking.models import (
|
|
||||||
CommentCreate,
|
|
||||||
JsonPatchOperation,
|
|
||||||
Wiql,
|
|
||||||
WorkItem,
|
|
||||||
)
|
|
||||||
from azure.devops.v6_0.work_item_tracking.work_item_tracking_client import (
|
|
||||||
WorkItemTrackingClient,
|
|
||||||
)
|
|
||||||
from memoization import cached
|
|
||||||
from onefuzztypes.models import ADOTemplate, RegressionReport, Report
|
|
||||||
from onefuzztypes.primitives import Container
|
|
||||||
|
|
||||||
from ..secrets import get_secret_string_value
|
|
||||||
from .common import Render, log_failed_notification
|
|
||||||
|
|
||||||
|
|
||||||
class AdoNotificationException(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@cached(ttl=60)
|
|
||||||
def get_ado_client(base_url: str, token: str) -> WorkItemTrackingClient:
|
|
||||||
connection = Connection(base_url=base_url, creds=BasicAuthentication("PAT", token))
|
|
||||||
client = connection.clients_v6_0.get_work_item_tracking_client()
|
|
||||||
return client
|
|
||||||
|
|
||||||
|
|
||||||
@cached(ttl=60)
|
|
||||||
def get_valid_fields(
|
|
||||||
client: WorkItemTrackingClient, project: Optional[str] = None
|
|
||||||
) -> List[str]:
|
|
||||||
valid_fields = [
|
|
||||||
x.reference_name.lower()
|
|
||||||
for x in client.get_fields(project=project, expand="ExtensionFields")
|
|
||||||
]
|
|
||||||
return valid_fields
|
|
||||||
|
|
||||||
|
|
||||||
class ADO:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
container: Container,
|
|
||||||
filename: str,
|
|
||||||
config: ADOTemplate,
|
|
||||||
report: Report,
|
|
||||||
*,
|
|
||||||
renderer: Optional[Render] = None,
|
|
||||||
):
|
|
||||||
self.config = config
|
|
||||||
if renderer:
|
|
||||||
self.renderer = renderer
|
|
||||||
else:
|
|
||||||
self.renderer = Render(container, filename, report)
|
|
||||||
self.project = self.render(self.config.project)
|
|
||||||
|
|
||||||
def connect(self) -> None:
|
|
||||||
auth_token = get_secret_string_value(self.config.auth_token)
|
|
||||||
self.client = get_ado_client(self.config.base_url, auth_token)
|
|
||||||
|
|
||||||
def render(self, template: str) -> str:
|
|
||||||
return self.renderer.render(template)
|
|
||||||
|
|
||||||
def existing_work_items(self) -> Iterator[WorkItem]:
|
|
||||||
filters = {}
|
|
||||||
for key in self.config.unique_fields:
|
|
||||||
if key == "System.TeamProject":
|
|
||||||
value = self.render(self.config.project)
|
|
||||||
else:
|
|
||||||
value = self.render(self.config.ado_fields[key])
|
|
||||||
filters[key.lower()] = value
|
|
||||||
|
|
||||||
valid_fields = get_valid_fields(
|
|
||||||
self.client, project=filters.get("system.teamproject")
|
|
||||||
)
|
|
||||||
|
|
||||||
post_query_filter = {}
|
|
||||||
|
|
||||||
# WIQL (Work Item Query Language) is an SQL like query language that
|
|
||||||
# doesn't support query params, safe quoting, or any other SQL-injection
|
|
||||||
# protection mechanisms.
|
|
||||||
#
|
|
||||||
# As such, build the WIQL with a those fields we can pre-determine are
|
|
||||||
# "safe" and otherwise use post-query filtering.
|
|
||||||
parts = []
|
|
||||||
for k, v in filters.items():
|
|
||||||
# Only add pre-system approved fields to the query
|
|
||||||
if k not in valid_fields:
|
|
||||||
post_query_filter[k] = v
|
|
||||||
continue
|
|
||||||
|
|
||||||
# WIQL supports wrapping values in ' or " and escaping ' by doubling it
|
|
||||||
#
|
|
||||||
# For this System.Title: hi'there
|
|
||||||
# use this query fragment: [System.Title] = 'hi''there'
|
|
||||||
#
|
|
||||||
# For this System.Title: hi"there
|
|
||||||
# use this query fragment: [System.Title] = 'hi"there'
|
|
||||||
#
|
|
||||||
# For this System.Title: hi'"there
|
|
||||||
# use this query fragment: [System.Title] = 'hi''"there'
|
|
||||||
SINGLE = "'"
|
|
||||||
parts.append("[%s] = '%s'" % (k, v.replace(SINGLE, SINGLE + SINGLE)))
|
|
||||||
|
|
||||||
query = "select [System.Id] from WorkItems"
|
|
||||||
if parts:
|
|
||||||
query += " where " + " AND ".join(parts)
|
|
||||||
|
|
||||||
wiql = Wiql(query=query)
|
|
||||||
for entry in self.client.query_by_wiql(wiql).work_items:
|
|
||||||
item = self.client.get_work_item(entry.id, expand="Fields")
|
|
||||||
lowered_fields = {x.lower(): str(y) for (x, y) in item.fields.items()}
|
|
||||||
if post_query_filter and not all(
|
|
||||||
[
|
|
||||||
k.lower() in lowered_fields and lowered_fields[k.lower()] == v
|
|
||||||
for (k, v) in post_query_filter.items()
|
|
||||||
]
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
yield item
|
|
||||||
|
|
||||||
def update_existing(self, item: WorkItem, notification_info: str) -> None:
|
|
||||||
if self.config.on_duplicate.comment:
|
|
||||||
comment = self.render(self.config.on_duplicate.comment)
|
|
||||||
self.client.add_comment(
|
|
||||||
CommentCreate(comment),
|
|
||||||
self.project,
|
|
||||||
item.id,
|
|
||||||
)
|
|
||||||
|
|
||||||
document = []
|
|
||||||
for field in self.config.on_duplicate.increment:
|
|
||||||
value = int(item.fields[field]) if field in item.fields else 0
|
|
||||||
value += 1
|
|
||||||
document.append(
|
|
||||||
JsonPatchOperation(
|
|
||||||
op="Replace", path="/fields/%s" % field, value=str(value)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
for field in self.config.on_duplicate.ado_fields:
|
|
||||||
field_value = self.render(self.config.on_duplicate.ado_fields[field])
|
|
||||||
document.append(
|
|
||||||
JsonPatchOperation(
|
|
||||||
op="Replace", path="/fields/%s" % field, value=field_value
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if item.fields["System.State"] in self.config.on_duplicate.set_state:
|
|
||||||
document.append(
|
|
||||||
JsonPatchOperation(
|
|
||||||
op="Replace",
|
|
||||||
path="/fields/System.State",
|
|
||||||
value=self.config.on_duplicate.set_state[
|
|
||||||
item.fields["System.State"]
|
|
||||||
],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if document:
|
|
||||||
self.client.update_work_item(document, item.id, project=self.project)
|
|
||||||
logging.info(
|
|
||||||
f"notify ado: updated work item {item.id} - {notification_info}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logging.info(
|
|
||||||
f"notify ado: no update for work item {item.id} - {notification_info}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def render_new(self) -> Tuple[str, List[JsonPatchOperation]]:
|
|
||||||
task_type = self.render(self.config.type)
|
|
||||||
document = []
|
|
||||||
if "System.Tags" not in self.config.ado_fields:
|
|
||||||
document.append(
|
|
||||||
JsonPatchOperation(
|
|
||||||
op="Add", path="/fields/System.Tags", value="Onefuzz"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
for field in self.config.ado_fields:
|
|
||||||
value = self.render(self.config.ado_fields[field])
|
|
||||||
if field == "System.Tags":
|
|
||||||
value += ";Onefuzz"
|
|
||||||
document.append(
|
|
||||||
JsonPatchOperation(op="Add", path="/fields/%s" % field, value=value)
|
|
||||||
)
|
|
||||||
return (task_type, document)
|
|
||||||
|
|
||||||
def create_new(self) -> WorkItem:
|
|
||||||
task_type, document = self.render_new()
|
|
||||||
|
|
||||||
entry = self.client.create_work_item(
|
|
||||||
document=document, project=self.project, type=task_type
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.config.comment:
|
|
||||||
comment = self.render(self.config.comment)
|
|
||||||
self.client.add_comment(
|
|
||||||
CommentCreate(comment),
|
|
||||||
self.project,
|
|
||||||
entry.id,
|
|
||||||
)
|
|
||||||
return entry
|
|
||||||
|
|
||||||
def process(self, notification_info: str) -> None:
|
|
||||||
seen = False
|
|
||||||
for work_item in self.existing_work_items():
|
|
||||||
self.update_existing(work_item, notification_info)
|
|
||||||
seen = True
|
|
||||||
|
|
||||||
if not seen:
|
|
||||||
entry = self.create_new()
|
|
||||||
logging.info(
|
|
||||||
"notify ado: created new work item" f" {entry.id} - {notification_info}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def is_transient(err: Exception) -> bool:
|
|
||||||
error_codes = [
|
|
||||||
# "TF401349: An unexpected error has occurred, please verify your request and try again." # noqa: E501
|
|
||||||
"TF401349",
|
|
||||||
# TF26071: This work item has been changed by someone else since you opened it. You will need to refresh it and discard your changes. # noqa: E501
|
|
||||||
"TF26071",
|
|
||||||
]
|
|
||||||
error_str = str(err)
|
|
||||||
for code in error_codes:
|
|
||||||
if code in error_str:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def notify_ado(
|
|
||||||
config: ADOTemplate,
|
|
||||||
container: Container,
|
|
||||||
filename: str,
|
|
||||||
report: Union[Report, RegressionReport],
|
|
||||||
fail_task_on_transient_error: bool,
|
|
||||||
notification_id: UUID,
|
|
||||||
) -> None:
|
|
||||||
if isinstance(report, RegressionReport):
|
|
||||||
logging.info(
|
|
||||||
"ado integration does not support regression reports. "
|
|
||||||
"container:%s filename:%s",
|
|
||||||
container,
|
|
||||||
filename,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
notification_info = (
|
|
||||||
f"job_id:{report.job_id} task_id:{report.task_id}"
|
|
||||||
f" container:{container} filename:{filename}"
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.info("notify ado: %s", notification_info)
|
|
||||||
|
|
||||||
try:
|
|
||||||
ado = ADO(container, filename, config, report)
|
|
||||||
ado.connect()
|
|
||||||
ado.process(notification_info)
|
|
||||||
except (
|
|
||||||
AzureDevOpsAuthenticationError,
|
|
||||||
AzureDevOpsClientError,
|
|
||||||
AzureDevOpsServiceError,
|
|
||||||
AzureDevOpsClientRequestError,
|
|
||||||
ValueError,
|
|
||||||
) as err:
|
|
||||||
|
|
||||||
if not fail_task_on_transient_error and is_transient(err):
|
|
||||||
raise AdoNotificationException(
|
|
||||||
f"transient ADO notification failure {notification_info}"
|
|
||||||
) from err
|
|
||||||
else:
|
|
||||||
log_failed_notification(report, err, notification_id)
|
|
||||||
raise AdoNotificationException(
|
|
||||||
"Sending file changed event for notification %s to poison queue"
|
|
||||||
% notification_id
|
|
||||||
) from err
|
|
@ -1,96 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Optional
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from jinja2.sandbox import SandboxedEnvironment
|
|
||||||
from onefuzztypes.models import Report
|
|
||||||
from onefuzztypes.primitives import Container
|
|
||||||
|
|
||||||
from ..azure.containers import auth_download_url
|
|
||||||
from ..azure.creds import get_instance_url
|
|
||||||
from ..jobs import Job
|
|
||||||
from ..tasks.config import get_setup_container
|
|
||||||
from ..tasks.main import Task
|
|
||||||
|
|
||||||
|
|
||||||
def log_failed_notification(
|
|
||||||
report: Report, error: Exception, notification_id: UUID
|
|
||||||
) -> None:
|
|
||||||
logging.error(
|
|
||||||
"notification failed: notification_id:%s job_id:%s task_id:%s err:%s",
|
|
||||||
notification_id,
|
|
||||||
report.job_id,
|
|
||||||
report.task_id,
|
|
||||||
error,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Render:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
container: Container,
|
|
||||||
filename: str,
|
|
||||||
report: Report,
|
|
||||||
*,
|
|
||||||
task: Optional[Task] = None,
|
|
||||||
job: Optional[Job] = None,
|
|
||||||
target_url: Optional[str] = None,
|
|
||||||
input_url: Optional[str] = None,
|
|
||||||
report_url: Optional[str] = None,
|
|
||||||
):
|
|
||||||
self.report = report
|
|
||||||
self.container = container
|
|
||||||
self.filename = filename
|
|
||||||
if not task:
|
|
||||||
task = Task.get(report.job_id, report.task_id)
|
|
||||||
if not task:
|
|
||||||
raise ValueError(f"invalid task {report.task_id}")
|
|
||||||
if not job:
|
|
||||||
job = Job.get(report.job_id)
|
|
||||||
if not job:
|
|
||||||
raise ValueError(f"invalid job {report.job_id}")
|
|
||||||
|
|
||||||
self.task_config = task.config
|
|
||||||
self.job_config = job.config
|
|
||||||
self.env = SandboxedEnvironment()
|
|
||||||
|
|
||||||
self.target_url = target_url
|
|
||||||
if not self.target_url:
|
|
||||||
setup_container = get_setup_container(task.config)
|
|
||||||
if setup_container:
|
|
||||||
self.target_url = auth_download_url(
|
|
||||||
setup_container, self.report.executable.replace("setup/", "", 1)
|
|
||||||
)
|
|
||||||
|
|
||||||
if report_url:
|
|
||||||
self.report_url = report_url
|
|
||||||
else:
|
|
||||||
self.report_url = auth_download_url(container, filename)
|
|
||||||
|
|
||||||
self.input_url = input_url
|
|
||||||
if not self.input_url:
|
|
||||||
if self.report.input_blob:
|
|
||||||
self.input_url = auth_download_url(
|
|
||||||
self.report.input_blob.container, self.report.input_blob.name
|
|
||||||
)
|
|
||||||
|
|
||||||
def render(self, template: str) -> str:
|
|
||||||
return self.env.from_string(template).render(
|
|
||||||
{
|
|
||||||
"report": self.report,
|
|
||||||
"task": self.task_config,
|
|
||||||
"job": self.job_config,
|
|
||||||
"report_url": self.report_url,
|
|
||||||
"input_url": self.input_url,
|
|
||||||
"target_url": self.target_url,
|
|
||||||
"report_container": self.container,
|
|
||||||
"report_filename": self.filename,
|
|
||||||
"repro_cmd": "onefuzz --endpoint %s repro create_and_connect %s %s"
|
|
||||||
% (get_instance_url(), self.container, self.filename),
|
|
||||||
}
|
|
||||||
)
|
|
@ -1,138 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import List, Optional, Union
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from github3 import login
|
|
||||||
from github3.exceptions import GitHubException
|
|
||||||
from github3.issues import Issue
|
|
||||||
from onefuzztypes.enums import GithubIssueSearchMatch
|
|
||||||
from onefuzztypes.models import (
|
|
||||||
GithubAuth,
|
|
||||||
GithubIssueTemplate,
|
|
||||||
RegressionReport,
|
|
||||||
Report,
|
|
||||||
)
|
|
||||||
from onefuzztypes.primitives import Container
|
|
||||||
|
|
||||||
from ..secrets import get_secret_obj
|
|
||||||
from .common import Render, log_failed_notification
|
|
||||||
|
|
||||||
|
|
||||||
class GithubIssue:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
config: GithubIssueTemplate,
|
|
||||||
container: Container,
|
|
||||||
filename: str,
|
|
||||||
report: Report,
|
|
||||||
):
|
|
||||||
self.config = config
|
|
||||||
self.report = report
|
|
||||||
if isinstance(config.auth.secret, GithubAuth):
|
|
||||||
auth = config.auth.secret
|
|
||||||
else:
|
|
||||||
auth = get_secret_obj(config.auth.secret.url, GithubAuth)
|
|
||||||
|
|
||||||
self.gh = login(username=auth.user, password=auth.personal_access_token)
|
|
||||||
self.renderer = Render(container, filename, report)
|
|
||||||
|
|
||||||
def render(self, field: str) -> str:
|
|
||||||
return self.renderer.render(field)
|
|
||||||
|
|
||||||
def existing(self) -> List[Issue]:
|
|
||||||
query = [
|
|
||||||
self.render(self.config.unique_search.string),
|
|
||||||
"repo:%s/%s"
|
|
||||||
% (
|
|
||||||
self.render(self.config.organization),
|
|
||||||
self.render(self.config.repository),
|
|
||||||
),
|
|
||||||
]
|
|
||||||
if self.config.unique_search.author:
|
|
||||||
query.append("author:%s" % self.render(self.config.unique_search.author))
|
|
||||||
|
|
||||||
if self.config.unique_search.state:
|
|
||||||
query.append("state:%s" % self.config.unique_search.state.name)
|
|
||||||
|
|
||||||
issues = []
|
|
||||||
title = self.render(self.config.title)
|
|
||||||
body = self.render(self.config.body)
|
|
||||||
for issue in self.gh.search_issues(" ".join(query)):
|
|
||||||
skip = False
|
|
||||||
for field in self.config.unique_search.field_match:
|
|
||||||
if field == GithubIssueSearchMatch.title and issue.title != title:
|
|
||||||
skip = True
|
|
||||||
break
|
|
||||||
if field == GithubIssueSearchMatch.body and issue.body != body:
|
|
||||||
skip = True
|
|
||||||
break
|
|
||||||
if not skip:
|
|
||||||
issues.append(issue)
|
|
||||||
|
|
||||||
return issues
|
|
||||||
|
|
||||||
def update(self, issue: Issue) -> None:
|
|
||||||
logging.info("updating issue: %s", issue)
|
|
||||||
if self.config.on_duplicate.comment:
|
|
||||||
issue.issue.create_comment(self.render(self.config.on_duplicate.comment))
|
|
||||||
if self.config.on_duplicate.labels:
|
|
||||||
labels = [self.render(x) for x in self.config.on_duplicate.labels]
|
|
||||||
issue.issue.edit(labels=labels)
|
|
||||||
if self.config.on_duplicate.reopen and issue.state != "open":
|
|
||||||
issue.issue.edit(state="open")
|
|
||||||
|
|
||||||
def create(self) -> None:
|
|
||||||
logging.info("creating issue")
|
|
||||||
|
|
||||||
assignees = [self.render(x) for x in self.config.assignees]
|
|
||||||
labels = list(set(["OneFuzz"] + [self.render(x) for x in self.config.labels]))
|
|
||||||
|
|
||||||
self.gh.create_issue(
|
|
||||||
self.render(self.config.organization),
|
|
||||||
self.render(self.config.repository),
|
|
||||||
self.render(self.config.title),
|
|
||||||
body=self.render(self.config.body),
|
|
||||||
labels=labels,
|
|
||||||
assignees=assignees,
|
|
||||||
)
|
|
||||||
|
|
||||||
def process(self) -> None:
|
|
||||||
issues = self.existing()
|
|
||||||
if issues:
|
|
||||||
self.update(issues[0])
|
|
||||||
else:
|
|
||||||
self.create()
|
|
||||||
|
|
||||||
|
|
||||||
def github_issue(
|
|
||||||
config: GithubIssueTemplate,
|
|
||||||
container: Container,
|
|
||||||
filename: str,
|
|
||||||
report: Optional[Union[Report, RegressionReport]],
|
|
||||||
notification_id: UUID,
|
|
||||||
) -> None:
|
|
||||||
if report is None:
|
|
||||||
return
|
|
||||||
if isinstance(report, RegressionReport):
|
|
||||||
logging.info(
|
|
||||||
"github issue integration does not support regression reports. "
|
|
||||||
"container:%s filename:%s",
|
|
||||||
container,
|
|
||||||
filename,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
handler = GithubIssue(config, container, filename, report)
|
|
||||||
handler.process()
|
|
||||||
except (GitHubException, ValueError) as err:
|
|
||||||
log_failed_notification(report, err, notification_id)
|
|
||||||
raise GitHubException(
|
|
||||||
"Sending file change event for notification %s to poison queue"
|
|
||||||
% notification_id
|
|
||||||
) from err
|
|
@ -1,200 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import List, Optional, Sequence, Tuple
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from memoization import cached
|
|
||||||
from onefuzztypes import models
|
|
||||||
from onefuzztypes.enums import ErrorCode, TaskState
|
|
||||||
from onefuzztypes.events import (
|
|
||||||
EventCrashReported,
|
|
||||||
EventFileAdded,
|
|
||||||
EventRegressionReported,
|
|
||||||
)
|
|
||||||
from onefuzztypes.models import (
|
|
||||||
ADOTemplate,
|
|
||||||
Error,
|
|
||||||
GithubIssueTemplate,
|
|
||||||
NotificationTemplate,
|
|
||||||
RegressionReport,
|
|
||||||
Report,
|
|
||||||
Result,
|
|
||||||
TeamsTemplate,
|
|
||||||
)
|
|
||||||
from onefuzztypes.primitives import Container
|
|
||||||
|
|
||||||
from ..azure.containers import container_exists, get_file_sas_url
|
|
||||||
from ..azure.queue import send_message
|
|
||||||
from ..azure.storage import StorageType
|
|
||||||
from ..events import send_event
|
|
||||||
from ..orm import ORMMixin
|
|
||||||
from ..reports import get_report_or_regression
|
|
||||||
from ..tasks.config import get_input_container_queues
|
|
||||||
from ..tasks.main import Task
|
|
||||||
from .ado import notify_ado
|
|
||||||
from .github_issues import github_issue
|
|
||||||
from .teams import notify_teams
|
|
||||||
|
|
||||||
|
|
||||||
class Notification(models.Notification, ORMMixin):
|
|
||||||
@classmethod
|
|
||||||
def get_by_id(cls, notification_id: UUID) -> Result["Notification"]:
|
|
||||||
notifications = cls.search(query={"notification_id": [notification_id]})
|
|
||||||
if not notifications:
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.INVALID_REQUEST, errors=["unable to find Notification"]
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(notifications) != 1:
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.INVALID_REQUEST,
|
|
||||||
errors=["error identifying Notification"],
|
|
||||||
)
|
|
||||||
notification = notifications[0]
|
|
||||||
return notification
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def key_fields(cls) -> Tuple[str, str]:
|
|
||||||
return ("notification_id", "container")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create(
|
|
||||||
cls, container: Container, config: NotificationTemplate, replace_existing: bool
|
|
||||||
) -> Result["Notification"]:
|
|
||||||
if not container_exists(container, StorageType.corpus):
|
|
||||||
return Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid container"])
|
|
||||||
|
|
||||||
if replace_existing:
|
|
||||||
existing = cls.search(query={"container": [container]})
|
|
||||||
for entry in existing:
|
|
||||||
logging.info(
|
|
||||||
"replacing existing notification: %s - %s",
|
|
||||||
entry.notification_id,
|
|
||||||
container,
|
|
||||||
)
|
|
||||||
entry.delete()
|
|
||||||
|
|
||||||
entry = cls(container=container, config=config)
|
|
||||||
entry.save()
|
|
||||||
logging.info(
|
|
||||||
"created notification. notification_id:%s container:%s",
|
|
||||||
entry.notification_id,
|
|
||||||
entry.container,
|
|
||||||
)
|
|
||||||
return entry
|
|
||||||
|
|
||||||
|
|
||||||
@cached(ttl=10)
|
|
||||||
def get_notifications(container: Container) -> List[Notification]:
|
|
||||||
return Notification.search(query={"container": [container]})
|
|
||||||
|
|
||||||
|
|
||||||
def get_regression_report_task(report: RegressionReport) -> Optional[Task]:
|
|
||||||
# crash_test_result is required, but report & no_repro are not
|
|
||||||
if report.crash_test_result.crash_report:
|
|
||||||
return Task.get(
|
|
||||||
report.crash_test_result.crash_report.job_id,
|
|
||||||
report.crash_test_result.crash_report.task_id,
|
|
||||||
)
|
|
||||||
if report.crash_test_result.no_repro:
|
|
||||||
return Task.get(
|
|
||||||
report.crash_test_result.no_repro.job_id,
|
|
||||||
report.crash_test_result.no_repro.task_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.error(
|
|
||||||
"unable to find crash_report or no_repro entry for report: %s",
|
|
||||||
report.json(include_none=False),
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
@cached(ttl=10)
|
|
||||||
def get_queue_tasks() -> Sequence[Tuple[Task, Sequence[str]]]:
|
|
||||||
results = []
|
|
||||||
for task in Task.search_states(states=TaskState.available()):
|
|
||||||
containers = get_input_container_queues(task.config)
|
|
||||||
if containers:
|
|
||||||
results.append((task, containers))
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def new_files(
|
|
||||||
container: Container, filename: str, fail_task_on_transient_error: bool
|
|
||||||
) -> None:
|
|
||||||
notifications = get_notifications(container)
|
|
||||||
|
|
||||||
report = get_report_or_regression(
|
|
||||||
container, filename, expect_reports=bool(notifications)
|
|
||||||
)
|
|
||||||
|
|
||||||
if notifications:
|
|
||||||
done = []
|
|
||||||
for notification in notifications:
|
|
||||||
# ignore duplicate configurations
|
|
||||||
if notification.config in done:
|
|
||||||
continue
|
|
||||||
done.append(notification.config)
|
|
||||||
|
|
||||||
if isinstance(notification.config, TeamsTemplate):
|
|
||||||
notify_teams(
|
|
||||||
notification.config,
|
|
||||||
container,
|
|
||||||
filename,
|
|
||||||
report,
|
|
||||||
notification.notification_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not report:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if isinstance(notification.config, ADOTemplate):
|
|
||||||
notify_ado(
|
|
||||||
notification.config,
|
|
||||||
container,
|
|
||||||
filename,
|
|
||||||
report,
|
|
||||||
fail_task_on_transient_error,
|
|
||||||
notification.notification_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(notification.config, GithubIssueTemplate):
|
|
||||||
github_issue(
|
|
||||||
notification.config,
|
|
||||||
container,
|
|
||||||
filename,
|
|
||||||
report,
|
|
||||||
notification.notification_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
for (task, containers) in get_queue_tasks():
|
|
||||||
if container in containers:
|
|
||||||
logging.info("queuing input %s %s %s", container, filename, task.task_id)
|
|
||||||
url = get_file_sas_url(
|
|
||||||
container, filename, StorageType.corpus, read=True, delete=True
|
|
||||||
)
|
|
||||||
send_message(task.task_id, bytes(url, "utf-8"), StorageType.corpus)
|
|
||||||
|
|
||||||
if isinstance(report, Report):
|
|
||||||
crash_report_event = EventCrashReported(
|
|
||||||
report=report, container=container, filename=filename
|
|
||||||
)
|
|
||||||
report_task = Task.get(report.job_id, report.task_id)
|
|
||||||
if report_task:
|
|
||||||
crash_report_event.task_config = report_task.config
|
|
||||||
send_event(crash_report_event)
|
|
||||||
elif isinstance(report, RegressionReport):
|
|
||||||
regression_event = EventRegressionReported(
|
|
||||||
regression_report=report, container=container, filename=filename
|
|
||||||
)
|
|
||||||
|
|
||||||
report_task = get_regression_report_task(report)
|
|
||||||
if report_task:
|
|
||||||
regression_event.task_config = report_task.config
|
|
||||||
send_event(regression_event)
|
|
||||||
else:
|
|
||||||
send_event(EventFileAdded(container=container, filename=filename))
|
|
@ -1,141 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any, Dict, List, Optional, Union
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
import requests
|
|
||||||
from onefuzztypes.models import RegressionReport, Report, TeamsTemplate
|
|
||||||
from onefuzztypes.primitives import Container
|
|
||||||
|
|
||||||
from ..azure.containers import auth_download_url
|
|
||||||
from ..secrets import get_secret_string_value
|
|
||||||
from ..tasks.config import get_setup_container
|
|
||||||
from ..tasks.main import Task
|
|
||||||
|
|
||||||
|
|
||||||
def markdown_escape(data: str) -> str:
|
|
||||||
values = r"\\*_{}[]()#+-.!" # noqa: P103
|
|
||||||
for value in values:
|
|
||||||
data = data.replace(value, "\\" + value)
|
|
||||||
data = data.replace("`", "``")
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
def code_block(data: str) -> str:
|
|
||||||
data = data.replace("`", "``")
|
|
||||||
return "\n```\n%s\n```\n" % data
|
|
||||||
|
|
||||||
|
|
||||||
def send_teams_webhook(
|
|
||||||
config: TeamsTemplate,
|
|
||||||
title: str,
|
|
||||||
facts: List[Dict[str, str]],
|
|
||||||
text: Optional[str],
|
|
||||||
notification_id: UUID,
|
|
||||||
) -> None:
|
|
||||||
title = markdown_escape(title)
|
|
||||||
|
|
||||||
message: Dict[str, Any] = {
|
|
||||||
"@type": "MessageCard",
|
|
||||||
"@context": "https://schema.org/extensions",
|
|
||||||
"summary": title,
|
|
||||||
"sections": [{"activityTitle": title, "facts": facts}],
|
|
||||||
}
|
|
||||||
|
|
||||||
if text:
|
|
||||||
message["sections"].append({"text": text})
|
|
||||||
|
|
||||||
config_url = get_secret_string_value(config.url)
|
|
||||||
response = requests.post(config_url, json=message)
|
|
||||||
if not response.ok:
|
|
||||||
logging.error(
|
|
||||||
"webhook failed notification_id:%s %s %s",
|
|
||||||
notification_id,
|
|
||||||
response.status_code,
|
|
||||||
response.content,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def notify_teams(
|
|
||||||
config: TeamsTemplate,
|
|
||||||
container: Container,
|
|
||||||
filename: str,
|
|
||||||
report: Optional[Union[Report, RegressionReport]],
|
|
||||||
notification_id: UUID,
|
|
||||||
) -> None:
|
|
||||||
text = None
|
|
||||||
facts: List[Dict[str, str]] = []
|
|
||||||
|
|
||||||
if isinstance(report, Report):
|
|
||||||
task = Task.get(report.job_id, report.task_id)
|
|
||||||
if not task:
|
|
||||||
logging.error(
|
|
||||||
"report with invalid task %s:%s", report.job_id, report.task_id
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
title = "new crash in %s: %s @ %s" % (
|
|
||||||
report.executable,
|
|
||||||
report.crash_type,
|
|
||||||
report.crash_site,
|
|
||||||
)
|
|
||||||
|
|
||||||
links = [
|
|
||||||
"[report](%s)" % auth_download_url(container, filename),
|
|
||||||
]
|
|
||||||
|
|
||||||
setup_container = get_setup_container(task.config)
|
|
||||||
if setup_container:
|
|
||||||
links.append(
|
|
||||||
"[executable](%s)"
|
|
||||||
% auth_download_url(
|
|
||||||
setup_container,
|
|
||||||
report.executable.replace("setup/", "", 1),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
if report.input_blob:
|
|
||||||
links.append(
|
|
||||||
"[input](%s)"
|
|
||||||
% auth_download_url(
|
|
||||||
report.input_blob.container, report.input_blob.name
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
facts += [
|
|
||||||
{"name": "Files", "value": " | ".join(links)},
|
|
||||||
{
|
|
||||||
"name": "Task",
|
|
||||||
"value": markdown_escape(
|
|
||||||
"job_id: %s task_id: %s" % (report.job_id, report.task_id)
|
|
||||||
),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Repro",
|
|
||||||
"value": code_block(
|
|
||||||
"onefuzz repro create_and_connect %s %s" % (container, filename)
|
|
||||||
),
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
text = "## Call Stack\n" + "\n".join(code_block(x) for x in report.call_stack)
|
|
||||||
|
|
||||||
else:
|
|
||||||
title = "new file found"
|
|
||||||
facts += [
|
|
||||||
{
|
|
||||||
"name": "file",
|
|
||||||
"value": "[%s/%s](%s)"
|
|
||||||
% (
|
|
||||||
markdown_escape(container),
|
|
||||||
markdown_escape(filename),
|
|
||||||
auth_download_url(container, filename),
|
|
||||||
),
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
send_teams_webhook(config, title, facts, text, notification_id)
|
|
@ -1,505 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import inspect
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from datetime import datetime
|
|
||||||
from enum import Enum
|
|
||||||
from typing import (
|
|
||||||
Any,
|
|
||||||
Callable,
|
|
||||||
Dict,
|
|
||||||
List,
|
|
||||||
Mapping,
|
|
||||||
Optional,
|
|
||||||
Tuple,
|
|
||||||
Type,
|
|
||||||
TypeVar,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from azure.common import AzureConflictHttpError, AzureMissingResourceHttpError
|
|
||||||
from onefuzztypes.enums import (
|
|
||||||
ErrorCode,
|
|
||||||
JobState,
|
|
||||||
NodeState,
|
|
||||||
PoolState,
|
|
||||||
ScalesetState,
|
|
||||||
TaskState,
|
|
||||||
TelemetryEvent,
|
|
||||||
UpdateType,
|
|
||||||
VmState,
|
|
||||||
)
|
|
||||||
from onefuzztypes.models import Error, SecretData
|
|
||||||
from onefuzztypes.primitives import Container, PoolName, Region
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from typing_extensions import Protocol
|
|
||||||
|
|
||||||
from .azure.table import get_client
|
|
||||||
from .secrets import delete_remote_secret_data, save_to_keyvault
|
|
||||||
from .telemetry import track_event_filtered
|
|
||||||
from .updates import queue_update
|
|
||||||
|
|
||||||
A = TypeVar("A", bound="ORMMixin")
|
|
||||||
|
|
||||||
QUERY_VALUE_TYPES = Union[
|
|
||||||
List[bool],
|
|
||||||
List[int],
|
|
||||||
List[str],
|
|
||||||
List[UUID],
|
|
||||||
List[Region],
|
|
||||||
List[Container],
|
|
||||||
List[PoolName],
|
|
||||||
List[VmState],
|
|
||||||
List[ScalesetState],
|
|
||||||
List[JobState],
|
|
||||||
List[TaskState],
|
|
||||||
List[PoolState],
|
|
||||||
List[NodeState],
|
|
||||||
]
|
|
||||||
QueryFilter = Dict[str, QUERY_VALUE_TYPES]
|
|
||||||
|
|
||||||
SAFE_STRINGS = (UUID, Container, Region, PoolName)
|
|
||||||
KEY = Union[int, str, UUID, Enum]
|
|
||||||
|
|
||||||
HOURS = 60 * 60
|
|
||||||
|
|
||||||
|
|
||||||
class HasState(Protocol):
|
|
||||||
# TODO: this should be bound tighter than Any
|
|
||||||
# In the end, we want this to be an Enum. Specifically, one of
|
|
||||||
# the JobState,TaskState,etc enums.
|
|
||||||
state: Any
|
|
||||||
|
|
||||||
def get_keys(self) -> Tuple[KEY, KEY]:
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
def process_state_update(obj: HasState) -> None:
|
|
||||||
"""
|
|
||||||
process a single state update, if the obj
|
|
||||||
implements a function for that state
|
|
||||||
"""
|
|
||||||
|
|
||||||
func = getattr(obj, obj.state.name, None)
|
|
||||||
if func is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
keys = obj.get_keys()
|
|
||||||
|
|
||||||
logging.info(
|
|
||||||
"processing state update: %s - %s - %s", type(obj), keys, obj.state.name
|
|
||||||
)
|
|
||||||
func()
|
|
||||||
|
|
||||||
|
|
||||||
def process_state_updates(obj: HasState, max_updates: int = 5) -> None:
|
|
||||||
"""process through the state machine for an object"""
|
|
||||||
|
|
||||||
for _ in range(max_updates):
|
|
||||||
state = obj.state
|
|
||||||
process_state_update(obj)
|
|
||||||
new_state = obj.state
|
|
||||||
if new_state == state:
|
|
||||||
break
|
|
||||||
|
|
||||||
|
|
||||||
def resolve(key: KEY) -> str:
|
|
||||||
if isinstance(key, str):
|
|
||||||
return key
|
|
||||||
elif isinstance(key, UUID):
|
|
||||||
return str(key)
|
|
||||||
elif isinstance(key, Enum):
|
|
||||||
return key.name
|
|
||||||
elif isinstance(key, int):
|
|
||||||
return str(key)
|
|
||||||
raise NotImplementedError("unsupported type %s - %s" % (type(key), repr(key)))
|
|
||||||
|
|
||||||
|
|
||||||
def build_filters(
|
|
||||||
cls: Type[A], query_args: Optional[QueryFilter]
|
|
||||||
) -> Tuple[Optional[str], QueryFilter]:
|
|
||||||
if not query_args:
|
|
||||||
return (None, {})
|
|
||||||
|
|
||||||
partition_key_field, row_key_field = cls.key_fields()
|
|
||||||
|
|
||||||
search_filter_parts = []
|
|
||||||
post_filters: QueryFilter = {}
|
|
||||||
|
|
||||||
for field, values in query_args.items():
|
|
||||||
if field not in cls.__fields__:
|
|
||||||
raise ValueError("unexpected field %s: %s" % (repr(field), cls))
|
|
||||||
|
|
||||||
if not values:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if field == partition_key_field:
|
|
||||||
field_name = "PartitionKey"
|
|
||||||
elif field == row_key_field:
|
|
||||||
field_name = "RowKey"
|
|
||||||
else:
|
|
||||||
field_name = field
|
|
||||||
|
|
||||||
parts: Optional[List[str]] = None
|
|
||||||
if isinstance(values[0], bool):
|
|
||||||
parts = []
|
|
||||||
for x in values:
|
|
||||||
if not isinstance(x, bool):
|
|
||||||
raise TypeError("unexpected type")
|
|
||||||
parts.append("%s eq %s" % (field_name, str(x).lower()))
|
|
||||||
elif isinstance(values[0], int):
|
|
||||||
parts = []
|
|
||||||
for x in values:
|
|
||||||
if not isinstance(x, int):
|
|
||||||
raise TypeError("unexpected type")
|
|
||||||
parts.append("%s eq %d" % (field_name, x))
|
|
||||||
elif isinstance(values[0], Enum):
|
|
||||||
parts = []
|
|
||||||
for x in values:
|
|
||||||
if not isinstance(x, Enum):
|
|
||||||
raise TypeError("unexpected type")
|
|
||||||
parts.append("%s eq '%s'" % (field_name, x.name))
|
|
||||||
elif all(isinstance(x, SAFE_STRINGS) for x in values):
|
|
||||||
parts = ["%s eq '%s'" % (field_name, x) for x in values]
|
|
||||||
else:
|
|
||||||
post_filters[field_name] = values
|
|
||||||
|
|
||||||
if parts:
|
|
||||||
if len(parts) == 1:
|
|
||||||
search_filter_parts.append(parts[0])
|
|
||||||
else:
|
|
||||||
search_filter_parts.append("(" + " or ".join(parts) + ")")
|
|
||||||
|
|
||||||
if search_filter_parts:
|
|
||||||
return (" and ".join(search_filter_parts), post_filters)
|
|
||||||
|
|
||||||
return (None, post_filters)
|
|
||||||
|
|
||||||
|
|
||||||
def post_filter(value: Any, filters: Optional[QueryFilter]) -> bool:
|
|
||||||
if not filters:
|
|
||||||
return True
|
|
||||||
|
|
||||||
for field in filters:
|
|
||||||
if field not in value:
|
|
||||||
return False
|
|
||||||
if value[field] not in filters[field]:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
MappingIntStrAny = Mapping[Union[int, str], Any]
|
|
||||||
# A = TypeVar("A", bound="Model")
|
|
||||||
|
|
||||||
|
|
||||||
class ModelMixin(BaseModel):
|
|
||||||
def export_exclude(self) -> Optional[MappingIntStrAny]:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def raw(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
by_alias: bool = False,
|
|
||||||
exclude_none: bool = False,
|
|
||||||
exclude: MappingIntStrAny = None,
|
|
||||||
include: MappingIntStrAny = None,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
# cycling through json means all wrapped types get resolved, such as UUID
|
|
||||||
result: Dict[str, Any] = json.loads(
|
|
||||||
self.json(
|
|
||||||
by_alias=by_alias,
|
|
||||||
exclude_none=exclude_none,
|
|
||||||
exclude=exclude,
|
|
||||||
include=include,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
B = TypeVar("B", bound=BaseModel)
|
|
||||||
|
|
||||||
|
|
||||||
def hide_secrets(data: B, hider: Callable[[SecretData], SecretData]) -> B:
|
|
||||||
for field in data.__fields__:
|
|
||||||
field_data = getattr(data, field)
|
|
||||||
|
|
||||||
if isinstance(field_data, SecretData):
|
|
||||||
field_data = hider(field_data)
|
|
||||||
elif isinstance(field_data, BaseModel):
|
|
||||||
field_data = hide_secrets(field_data, hider)
|
|
||||||
elif isinstance(field_data, list):
|
|
||||||
field_data = [
|
|
||||||
hide_secrets(x, hider) if isinstance(x, BaseModel) else x
|
|
||||||
for x in field_data
|
|
||||||
]
|
|
||||||
elif isinstance(field_data, dict):
|
|
||||||
for key in field_data:
|
|
||||||
if not isinstance(field_data[key], BaseModel):
|
|
||||||
continue
|
|
||||||
field_data[key] = hide_secrets(field_data[key], hider)
|
|
||||||
|
|
||||||
setattr(data, field, field_data)
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
# NOTE: the actual deletion must come from the `deleter` callback function
|
|
||||||
def delete_secrets(data: B, deleter: Callable[[SecretData], None]) -> None:
|
|
||||||
for field in data.__fields__:
|
|
||||||
field_data = getattr(data, field)
|
|
||||||
if isinstance(field_data, SecretData):
|
|
||||||
deleter(field_data)
|
|
||||||
elif isinstance(field_data, BaseModel):
|
|
||||||
delete_secrets(field_data, deleter)
|
|
||||||
elif isinstance(field_data, list):
|
|
||||||
for entry in field_data:
|
|
||||||
if isinstance(entry, BaseModel):
|
|
||||||
delete_secrets(entry, deleter)
|
|
||||||
elif isinstance(entry, SecretData):
|
|
||||||
deleter(entry)
|
|
||||||
elif isinstance(field_data, dict):
|
|
||||||
for value in field_data.values():
|
|
||||||
if isinstance(value, BaseModel):
|
|
||||||
delete_secrets(value, deleter)
|
|
||||||
elif isinstance(value, SecretData):
|
|
||||||
deleter(value)
|
|
||||||
|
|
||||||
|
|
||||||
# NOTE: if you want to include Timestamp in a model that uses ORMMixin,
|
|
||||||
# it must be maintained as part of the model.
|
|
||||||
class ORMMixin(ModelMixin):
|
|
||||||
etag: Optional[str]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def table_name(cls: Type[A]) -> str:
|
|
||||||
return cls.__name__
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get(
|
|
||||||
cls: Type[A], PartitionKey: KEY, RowKey: Optional[KEY] = None
|
|
||||||
) -> Optional[A]:
|
|
||||||
client = get_client(table=cls.table_name())
|
|
||||||
partition_key = resolve(PartitionKey)
|
|
||||||
row_key = resolve(RowKey) if RowKey else partition_key
|
|
||||||
|
|
||||||
try:
|
|
||||||
raw = client.get_entity(cls.table_name(), partition_key, row_key)
|
|
||||||
except AzureMissingResourceHttpError:
|
|
||||||
return None
|
|
||||||
return cls.load(raw)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def key_fields(cls) -> Tuple[str, Optional[str]]:
|
|
||||||
raise NotImplementedError("keys not defined")
|
|
||||||
|
|
||||||
# FILTERS:
|
|
||||||
# The following
|
|
||||||
# * save_exclude: Specify fields to *exclude* from saving to Storage Tables
|
|
||||||
# * export_exclude: Specify the fields to *exclude* from sending to an external API
|
|
||||||
# * telemetry_include: Specify the fields to *include* for telemetry
|
|
||||||
#
|
|
||||||
# For implementation details see:
|
|
||||||
# https://pydantic-docs.helpmanual.io/usage/exporting_models/#advanced-include-and-exclude
|
|
||||||
def save_exclude(self) -> Optional[MappingIntStrAny]:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def export_exclude(self) -> Optional[MappingIntStrAny]:
|
|
||||||
return {"etag": ...}
|
|
||||||
|
|
||||||
def telemetry_include(self) -> Optional[MappingIntStrAny]:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def telemetry(self) -> Any:
|
|
||||||
return self.raw(exclude_none=True, include=self.telemetry_include())
|
|
||||||
|
|
||||||
def get_keys(self) -> Tuple[KEY, KEY]:
|
|
||||||
partition_key_field, row_key_field = self.key_fields()
|
|
||||||
|
|
||||||
partition_key = getattr(self, partition_key_field)
|
|
||||||
if row_key_field:
|
|
||||||
row_key = getattr(self, row_key_field)
|
|
||||||
else:
|
|
||||||
row_key = partition_key
|
|
||||||
|
|
||||||
return (partition_key, row_key)
|
|
||||||
|
|
||||||
def save(self, new: bool = False, require_etag: bool = False) -> Optional[Error]:
|
|
||||||
self = hide_secrets(self, save_to_keyvault)
|
|
||||||
# TODO: migrate to an inspect.signature() model
|
|
||||||
raw = self.raw(by_alias=True, exclude_none=True, exclude=self.save_exclude())
|
|
||||||
for key in raw:
|
|
||||||
if not isinstance(raw[key], (str, int)):
|
|
||||||
raw[key] = json.dumps(raw[key])
|
|
||||||
|
|
||||||
for field in self.__fields__:
|
|
||||||
if field not in raw:
|
|
||||||
continue
|
|
||||||
# for datetime fields that passed through filtering, use the real value,
|
|
||||||
# rather than a serialized form
|
|
||||||
if self.__fields__[field].type_ == datetime:
|
|
||||||
raw[field] = getattr(self, field)
|
|
||||||
|
|
||||||
partition_key_field, row_key_field = self.key_fields()
|
|
||||||
|
|
||||||
# PartitionKey and RowKey must be 'str'
|
|
||||||
raw["PartitionKey"] = resolve(raw[partition_key_field])
|
|
||||||
raw["RowKey"] = resolve(raw[row_key_field or partition_key_field])
|
|
||||||
|
|
||||||
del raw[partition_key_field]
|
|
||||||
if row_key_field in raw:
|
|
||||||
del raw[row_key_field]
|
|
||||||
|
|
||||||
client = get_client(table=self.table_name())
|
|
||||||
|
|
||||||
# never save the timestamp
|
|
||||||
if "Timestamp" in raw:
|
|
||||||
del raw["Timestamp"]
|
|
||||||
|
|
||||||
if new:
|
|
||||||
try:
|
|
||||||
self.etag = client.insert_entity(self.table_name(), raw)
|
|
||||||
except AzureConflictHttpError:
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.UNABLE_TO_CREATE, errors=["entry already exists"]
|
|
||||||
)
|
|
||||||
elif self.etag and require_etag:
|
|
||||||
self.etag = client.replace_entity(
|
|
||||||
self.table_name(), raw, if_match=self.etag
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.etag = client.insert_or_replace_entity(self.table_name(), raw)
|
|
||||||
|
|
||||||
if self.table_name() in TelemetryEvent.__members__:
|
|
||||||
telem = self.telemetry()
|
|
||||||
if telem:
|
|
||||||
track_event_filtered(TelemetryEvent[self.table_name()], telem)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def delete(self) -> None:
|
|
||||||
partition_key, row_key = self.get_keys()
|
|
||||||
|
|
||||||
delete_secrets(self, delete_remote_secret_data)
|
|
||||||
|
|
||||||
client = get_client()
|
|
||||||
try:
|
|
||||||
client.delete_entity(
|
|
||||||
self.table_name(), resolve(partition_key), resolve(row_key)
|
|
||||||
)
|
|
||||||
except AzureMissingResourceHttpError:
|
|
||||||
# It's OK if the component is already deleted
|
|
||||||
pass
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load(cls: Type[A], data: Dict[str, Union[str, bytes, bytearray]]) -> A:
|
|
||||||
partition_key_field, row_key_field = cls.key_fields()
|
|
||||||
|
|
||||||
if partition_key_field in data:
|
|
||||||
raise Exception(
|
|
||||||
"duplicate PartitionKey field %s for %s"
|
|
||||||
% (partition_key_field, cls.table_name())
|
|
||||||
)
|
|
||||||
if row_key_field in data:
|
|
||||||
raise Exception(
|
|
||||||
"duplicate RowKey field %s for %s" % (row_key_field, cls.table_name())
|
|
||||||
)
|
|
||||||
|
|
||||||
data[partition_key_field] = data["PartitionKey"]
|
|
||||||
if row_key_field is not None:
|
|
||||||
data[row_key_field] = data["RowKey"]
|
|
||||||
|
|
||||||
del data["PartitionKey"]
|
|
||||||
del data["RowKey"]
|
|
||||||
|
|
||||||
for key in inspect.signature(cls).parameters:
|
|
||||||
if key not in data:
|
|
||||||
continue
|
|
||||||
|
|
||||||
annotation = inspect.signature(cls).parameters[key].annotation
|
|
||||||
|
|
||||||
if inspect.isclass(annotation):
|
|
||||||
if (
|
|
||||||
issubclass(annotation, BaseModel)
|
|
||||||
or issubclass(annotation, dict)
|
|
||||||
or issubclass(annotation, list)
|
|
||||||
):
|
|
||||||
data[key] = json.loads(data[key])
|
|
||||||
continue
|
|
||||||
|
|
||||||
if getattr(annotation, "__origin__", None) == Union and any(
|
|
||||||
inspect.isclass(x) and issubclass(x, BaseModel)
|
|
||||||
for x in annotation.__args__
|
|
||||||
):
|
|
||||||
data[key] = json.loads(data[key])
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Required for Python >=3.7. In 3.6, a `Dict[_,_]` and `List[_]` annotations
|
|
||||||
# are a class according to `inspect.isclass`.
|
|
||||||
if getattr(annotation, "__origin__", None) in [dict, list]:
|
|
||||||
data[key] = json.loads(data[key])
|
|
||||||
continue
|
|
||||||
|
|
||||||
return cls.parse_obj(data)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def search(
|
|
||||||
cls: Type[A],
|
|
||||||
*,
|
|
||||||
query: Optional[QueryFilter] = None,
|
|
||||||
raw_unchecked_filter: Optional[str] = None,
|
|
||||||
num_results: int = None,
|
|
||||||
) -> List[A]:
|
|
||||||
search_filter, post_filters = build_filters(cls, query)
|
|
||||||
|
|
||||||
if raw_unchecked_filter is not None:
|
|
||||||
if search_filter is None:
|
|
||||||
search_filter = raw_unchecked_filter
|
|
||||||
else:
|
|
||||||
search_filter = "(%s) and (%s)" % (search_filter, raw_unchecked_filter)
|
|
||||||
|
|
||||||
client = get_client(table=cls.table_name())
|
|
||||||
entries = []
|
|
||||||
for row in client.query_entities(
|
|
||||||
cls.table_name(), filter=search_filter, num_results=num_results
|
|
||||||
):
|
|
||||||
if not post_filter(row, post_filters):
|
|
||||||
continue
|
|
||||||
|
|
||||||
entry = cls.load(row)
|
|
||||||
entries.append(entry)
|
|
||||||
return entries
|
|
||||||
|
|
||||||
def queue(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
method: Optional[Callable] = None,
|
|
||||||
visibility_timeout: Optional[int] = None,
|
|
||||||
) -> None:
|
|
||||||
if not hasattr(self, "state"):
|
|
||||||
raise NotImplementedError("Queued an ORM mapping without State")
|
|
||||||
|
|
||||||
update_type = UpdateType.__members__.get(type(self).__name__)
|
|
||||||
if update_type is None:
|
|
||||||
raise NotImplementedError("unsupported update type: %s" % self)
|
|
||||||
|
|
||||||
method_name: Optional[str] = None
|
|
||||||
if method is not None:
|
|
||||||
if not hasattr(method, "__name__"):
|
|
||||||
raise Exception("unable to queue method: %s" % method)
|
|
||||||
method_name = method.__name__
|
|
||||||
|
|
||||||
partition_key, row_key = self.get_keys()
|
|
||||||
|
|
||||||
queue_update(
|
|
||||||
update_type,
|
|
||||||
resolve(partition_key),
|
|
||||||
resolve(row_key),
|
|
||||||
method=method_name,
|
|
||||||
visibility_timeout=visibility_timeout,
|
|
||||||
)
|
|
@ -1,346 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import datetime
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
from uuid import UUID, uuid4
|
|
||||||
|
|
||||||
import base58
|
|
||||||
from azure.mgmt.compute.models import VirtualMachine
|
|
||||||
from onefuzztypes.enums import ErrorCode, VmState
|
|
||||||
from onefuzztypes.events import (
|
|
||||||
EventProxyCreated,
|
|
||||||
EventProxyDeleted,
|
|
||||||
EventProxyFailed,
|
|
||||||
EventProxyStateUpdated,
|
|
||||||
)
|
|
||||||
from onefuzztypes.models import (
|
|
||||||
Authentication,
|
|
||||||
Error,
|
|
||||||
Forward,
|
|
||||||
ProxyConfig,
|
|
||||||
ProxyHeartbeat,
|
|
||||||
)
|
|
||||||
from onefuzztypes.primitives import Container, Region
|
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
from .__version__ import __version__
|
|
||||||
from .azure.auth import build_auth
|
|
||||||
from .azure.containers import get_file_sas_url, save_blob
|
|
||||||
from .azure.creds import get_instance_id
|
|
||||||
from .azure.ip import get_public_ip
|
|
||||||
from .azure.nsg import NSG
|
|
||||||
from .azure.queue import get_queue_sas
|
|
||||||
from .azure.storage import StorageType
|
|
||||||
from .azure.vm import VM
|
|
||||||
from .config import InstanceConfig
|
|
||||||
from .events import send_event
|
|
||||||
from .extension import proxy_manager_extensions
|
|
||||||
from .orm import ORMMixin, QueryFilter
|
|
||||||
from .proxy_forward import ProxyForward
|
|
||||||
|
|
||||||
PROXY_LOG_PREFIX = "scaleset-proxy: "
|
|
||||||
PROXY_LIFESPAN = datetime.timedelta(days=7)
|
|
||||||
|
|
||||||
|
|
||||||
# This isn't intended to ever be shared to the client, hence not being in
|
|
||||||
# onefuzztypes
|
|
||||||
class Proxy(ORMMixin):
|
|
||||||
timestamp: Optional[datetime.datetime] = Field(alias="Timestamp")
|
|
||||||
created_timestamp: datetime.datetime = Field(
|
|
||||||
default_factory=datetime.datetime.utcnow
|
|
||||||
)
|
|
||||||
proxy_id: UUID = Field(default_factory=uuid4)
|
|
||||||
region: Region
|
|
||||||
state: VmState = Field(default=VmState.init)
|
|
||||||
auth: Authentication = Field(default_factory=build_auth)
|
|
||||||
ip: Optional[str]
|
|
||||||
error: Optional[Error]
|
|
||||||
version: str = Field(default=__version__)
|
|
||||||
heartbeat: Optional[ProxyHeartbeat]
|
|
||||||
outdated: bool = Field(default=False)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def key_fields(cls) -> Tuple[str, Optional[str]]:
|
|
||||||
return ("region", "proxy_id")
|
|
||||||
|
|
||||||
def get_vm(self, config: InstanceConfig) -> VM:
|
|
||||||
config = InstanceConfig.fetch()
|
|
||||||
sku = config.proxy_vm_sku
|
|
||||||
tags = None
|
|
||||||
if config.vm_tags:
|
|
||||||
tags = config.vm_tags
|
|
||||||
vm = VM(
|
|
||||||
name="proxy-%s" % base58.b58encode(self.proxy_id.bytes).decode(),
|
|
||||||
region=self.region,
|
|
||||||
sku=sku,
|
|
||||||
image=config.default_linux_vm_image,
|
|
||||||
auth=self.auth,
|
|
||||||
tags=tags,
|
|
||||||
)
|
|
||||||
return vm
|
|
||||||
|
|
||||||
def init(self) -> None:
|
|
||||||
config = InstanceConfig.fetch()
|
|
||||||
vm = self.get_vm(config)
|
|
||||||
vm_data = vm.get()
|
|
||||||
if vm_data:
|
|
||||||
if vm_data.provisioning_state == "Failed":
|
|
||||||
self.set_provision_failed(vm_data)
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
self.save_proxy_config()
|
|
||||||
self.set_state(VmState.extensions_launch)
|
|
||||||
else:
|
|
||||||
nsg = NSG(
|
|
||||||
name=self.region,
|
|
||||||
region=self.region,
|
|
||||||
)
|
|
||||||
|
|
||||||
result = nsg.create()
|
|
||||||
if isinstance(result, Error):
|
|
||||||
self.set_failed(result)
|
|
||||||
return
|
|
||||||
|
|
||||||
nsg_config = config.proxy_nsg_config
|
|
||||||
result = nsg.set_allowed_sources(nsg_config)
|
|
||||||
if isinstance(result, Error):
|
|
||||||
self.set_failed(result)
|
|
||||||
return
|
|
||||||
|
|
||||||
vm.nsg = nsg
|
|
||||||
|
|
||||||
result = vm.create()
|
|
||||||
if isinstance(result, Error):
|
|
||||||
self.set_failed(result)
|
|
||||||
return
|
|
||||||
self.save()
|
|
||||||
|
|
||||||
def set_provision_failed(self, vm_data: VirtualMachine) -> None:
|
|
||||||
errors = ["provisioning failed"]
|
|
||||||
for status in vm_data.instance_view.statuses:
|
|
||||||
if status.level.name.lower() == "error":
|
|
||||||
errors.append(
|
|
||||||
f"code:{status.code} status:{status.display_status} "
|
|
||||||
f"message:{status.message}"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.set_failed(
|
|
||||||
Error(
|
|
||||||
code=ErrorCode.PROXY_FAILED,
|
|
||||||
errors=errors,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
def set_failed(self, error: Error) -> None:
|
|
||||||
if self.error is not None:
|
|
||||||
return
|
|
||||||
|
|
||||||
logging.error(PROXY_LOG_PREFIX + "vm failed: %s - %s", self.region, error)
|
|
||||||
send_event(
|
|
||||||
EventProxyFailed(region=self.region, proxy_id=self.proxy_id, error=error)
|
|
||||||
)
|
|
||||||
self.error = error
|
|
||||||
self.set_state(VmState.stopping)
|
|
||||||
|
|
||||||
def extensions_launch(self) -> None:
|
|
||||||
config = InstanceConfig.fetch()
|
|
||||||
vm = self.get_vm(config)
|
|
||||||
vm_data = vm.get()
|
|
||||||
if not vm_data:
|
|
||||||
self.set_failed(
|
|
||||||
Error(
|
|
||||||
code=ErrorCode.PROXY_FAILED,
|
|
||||||
errors=["azure not able to find vm"],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
if vm_data.provisioning_state == "Failed":
|
|
||||||
self.set_provision_failed(vm_data)
|
|
||||||
return
|
|
||||||
|
|
||||||
ip = get_public_ip(vm_data.network_profile.network_interfaces[0].id)
|
|
||||||
if ip is None:
|
|
||||||
self.save()
|
|
||||||
return
|
|
||||||
self.ip = ip
|
|
||||||
|
|
||||||
extensions = proxy_manager_extensions(self.region, self.proxy_id)
|
|
||||||
result = vm.add_extensions(extensions)
|
|
||||||
if isinstance(result, Error):
|
|
||||||
self.set_failed(result)
|
|
||||||
return
|
|
||||||
elif result:
|
|
||||||
self.set_state(VmState.running)
|
|
||||||
|
|
||||||
self.save()
|
|
||||||
|
|
||||||
def stopping(self) -> None:
|
|
||||||
config = InstanceConfig.fetch()
|
|
||||||
vm = self.get_vm(config)
|
|
||||||
if not vm.is_deleted():
|
|
||||||
logging.info(PROXY_LOG_PREFIX + "stopping proxy: %s", self.region)
|
|
||||||
vm.delete()
|
|
||||||
self.save()
|
|
||||||
else:
|
|
||||||
self.stopped()
|
|
||||||
|
|
||||||
def stopped(self) -> None:
|
|
||||||
self.set_state(VmState.stopped)
|
|
||||||
logging.info(PROXY_LOG_PREFIX + "removing proxy: %s", self.region)
|
|
||||||
send_event(EventProxyDeleted(region=self.region, proxy_id=self.proxy_id))
|
|
||||||
self.delete()
|
|
||||||
|
|
||||||
def is_outdated(self) -> bool:
|
|
||||||
if self.state not in VmState.available():
|
|
||||||
return True
|
|
||||||
|
|
||||||
if self.version != __version__:
|
|
||||||
logging.info(
|
|
||||||
PROXY_LOG_PREFIX + "mismatch version: proxy:%s service:%s state:%s",
|
|
||||||
self.version,
|
|
||||||
__version__,
|
|
||||||
self.state,
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
if self.created_timestamp is not None:
|
|
||||||
proxy_timestamp = self.created_timestamp
|
|
||||||
if proxy_timestamp < (
|
|
||||||
datetime.datetime.now(tz=datetime.timezone.utc) - PROXY_LIFESPAN
|
|
||||||
):
|
|
||||||
logging.info(
|
|
||||||
PROXY_LOG_PREFIX
|
|
||||||
+ "proxy older than 7 days:proxy-created:%s state:%s",
|
|
||||||
self.created_timestamp,
|
|
||||||
self.state,
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def is_used(self) -> bool:
|
|
||||||
if len(self.get_forwards()) == 0:
|
|
||||||
logging.info(PROXY_LOG_PREFIX + "no forwards: %s", self.region)
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def is_alive(self) -> bool:
|
|
||||||
# Unfortunately, with and without TZ information is required for compare
|
|
||||||
# or exceptions are generated
|
|
||||||
ten_minutes_ago_no_tz = datetime.datetime.utcnow() - datetime.timedelta(
|
|
||||||
minutes=10
|
|
||||||
)
|
|
||||||
ten_minutes_ago = ten_minutes_ago_no_tz.astimezone(datetime.timezone.utc)
|
|
||||||
if (
|
|
||||||
self.heartbeat is not None
|
|
||||||
and self.heartbeat.timestamp < ten_minutes_ago_no_tz
|
|
||||||
):
|
|
||||||
logging.error(
|
|
||||||
PROXY_LOG_PREFIX + "last heartbeat is more than an 10 minutes old: "
|
|
||||||
"%s - last heartbeat:%s compared_to:%s",
|
|
||||||
self.region,
|
|
||||||
self.heartbeat,
|
|
||||||
ten_minutes_ago_no_tz,
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
elif not self.heartbeat and self.timestamp and self.timestamp < ten_minutes_ago:
|
|
||||||
logging.error(
|
|
||||||
PROXY_LOG_PREFIX + "no heartbeat in the last 10 minutes: "
|
|
||||||
"%s timestamp: %s compared_to:%s",
|
|
||||||
self.region,
|
|
||||||
self.timestamp,
|
|
||||||
ten_minutes_ago,
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def get_forwards(self) -> List[Forward]:
|
|
||||||
forwards: List[Forward] = []
|
|
||||||
for entry in ProxyForward.search_forward(
|
|
||||||
region=self.region, proxy_id=self.proxy_id
|
|
||||||
):
|
|
||||||
if entry.endtime < datetime.datetime.now(tz=datetime.timezone.utc):
|
|
||||||
entry.delete()
|
|
||||||
else:
|
|
||||||
forwards.append(
|
|
||||||
Forward(
|
|
||||||
src_port=entry.port,
|
|
||||||
dst_ip=entry.dst_ip,
|
|
||||||
dst_port=entry.dst_port,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return forwards
|
|
||||||
|
|
||||||
def save_proxy_config(self) -> None:
|
|
||||||
forwards = self.get_forwards()
|
|
||||||
proxy_config = ProxyConfig(
|
|
||||||
url=get_file_sas_url(
|
|
||||||
Container("proxy-configs"),
|
|
||||||
"%s/%s/config.json" % (self.region, self.proxy_id),
|
|
||||||
StorageType.config,
|
|
||||||
read=True,
|
|
||||||
),
|
|
||||||
notification=get_queue_sas(
|
|
||||||
"proxy",
|
|
||||||
StorageType.config,
|
|
||||||
add=True,
|
|
||||||
),
|
|
||||||
forwards=forwards,
|
|
||||||
region=self.region,
|
|
||||||
proxy_id=self.proxy_id,
|
|
||||||
instance_telemetry_key=os.environ.get("APPINSIGHTS_INSTRUMENTATIONKEY"),
|
|
||||||
microsoft_telemetry_key=os.environ.get("ONEFUZZ_TELEMETRY"),
|
|
||||||
instance_id=get_instance_id(),
|
|
||||||
)
|
|
||||||
|
|
||||||
save_blob(
|
|
||||||
Container("proxy-configs"),
|
|
||||||
"%s/%s/config.json" % (self.region, self.proxy_id),
|
|
||||||
proxy_config.json(),
|
|
||||||
StorageType.config,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def search_states(cls, *, states: Optional[List[VmState]] = None) -> List["Proxy"]:
|
|
||||||
query: QueryFilter = {}
|
|
||||||
if states:
|
|
||||||
query["state"] = states
|
|
||||||
return cls.search(query=query)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_or_create(cls, region: Region) -> Optional["Proxy"]:
|
|
||||||
proxy_list = Proxy.search(query={"region": [region], "outdated": [False]})
|
|
||||||
for proxy in proxy_list:
|
|
||||||
if proxy.is_outdated():
|
|
||||||
proxy.outdated = True
|
|
||||||
proxy.save()
|
|
||||||
continue
|
|
||||||
if proxy.state not in VmState.available():
|
|
||||||
continue
|
|
||||||
return proxy
|
|
||||||
|
|
||||||
logging.info(PROXY_LOG_PREFIX + "creating proxy: region:%s", region)
|
|
||||||
proxy = Proxy(region=region)
|
|
||||||
proxy.save()
|
|
||||||
send_event(EventProxyCreated(region=region, proxy_id=proxy.proxy_id))
|
|
||||||
return proxy
|
|
||||||
|
|
||||||
def set_state(self, state: VmState) -> None:
|
|
||||||
if self.state == state:
|
|
||||||
return
|
|
||||||
|
|
||||||
self.state = state
|
|
||||||
self.save()
|
|
||||||
|
|
||||||
send_event(
|
|
||||||
EventProxyStateUpdated(
|
|
||||||
region=self.region, proxy_id=self.proxy_id, state=self.state
|
|
||||||
)
|
|
||||||
)
|
|
@ -1,143 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import datetime
|
|
||||||
import logging
|
|
||||||
from typing import List, Optional, Tuple, Union
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from onefuzztypes.enums import ErrorCode
|
|
||||||
from onefuzztypes.models import Error, Forward
|
|
||||||
from onefuzztypes.primitives import Region
|
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
from .azure.ip import get_scaleset_instance_ip
|
|
||||||
from .orm import ORMMixin, QueryFilter
|
|
||||||
|
|
||||||
PORT_RANGES = range(28000, 32000)
|
|
||||||
|
|
||||||
|
|
||||||
# This isn't intended to ever be shared to the client, hence not being in
|
|
||||||
# onefuzztypes
|
|
||||||
class ProxyForward(ORMMixin):
|
|
||||||
region: Region
|
|
||||||
port: int
|
|
||||||
scaleset_id: UUID
|
|
||||||
machine_id: UUID
|
|
||||||
proxy_id: Optional[UUID]
|
|
||||||
dst_ip: str
|
|
||||||
dst_port: int
|
|
||||||
endtime: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def key_fields(cls) -> Tuple[str, str]:
|
|
||||||
return ("region", "port")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def update_or_create(
|
|
||||||
cls,
|
|
||||||
region: Region,
|
|
||||||
scaleset_id: UUID,
|
|
||||||
machine_id: UUID,
|
|
||||||
dst_port: int,
|
|
||||||
duration: int,
|
|
||||||
) -> Union["ProxyForward", Error]:
|
|
||||||
private_ip = get_scaleset_instance_ip(scaleset_id, machine_id)
|
|
||||||
if not private_ip:
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.UNABLE_TO_PORT_FORWARD, errors=["no private ip for node"]
|
|
||||||
)
|
|
||||||
|
|
||||||
entries = cls.search_forward(
|
|
||||||
scaleset_id=scaleset_id,
|
|
||||||
machine_id=machine_id,
|
|
||||||
dst_port=dst_port,
|
|
||||||
region=region,
|
|
||||||
)
|
|
||||||
if entries:
|
|
||||||
entry = entries[0]
|
|
||||||
entry.endtime = datetime.datetime.utcnow() + datetime.timedelta(
|
|
||||||
hours=duration
|
|
||||||
)
|
|
||||||
entry.save()
|
|
||||||
return entry
|
|
||||||
|
|
||||||
existing = [int(x.port) for x in entries]
|
|
||||||
for port in PORT_RANGES:
|
|
||||||
if port in existing:
|
|
||||||
continue
|
|
||||||
|
|
||||||
entry = cls(
|
|
||||||
region=region,
|
|
||||||
port=port,
|
|
||||||
scaleset_id=scaleset_id,
|
|
||||||
machine_id=machine_id,
|
|
||||||
dst_ip=private_ip,
|
|
||||||
dst_port=dst_port,
|
|
||||||
endtime=datetime.datetime.utcnow() + datetime.timedelta(hours=duration),
|
|
||||||
)
|
|
||||||
result = entry.save(new=True)
|
|
||||||
if isinstance(result, Error):
|
|
||||||
logging.info("port is already used: %s", entry)
|
|
||||||
continue
|
|
||||||
|
|
||||||
return entry
|
|
||||||
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.UNABLE_TO_PORT_FORWARD, errors=["all forward ports used"]
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def remove_forward(
|
|
||||||
cls,
|
|
||||||
scaleset_id: UUID,
|
|
||||||
*,
|
|
||||||
proxy_id: Optional[UUID] = None,
|
|
||||||
machine_id: Optional[UUID] = None,
|
|
||||||
dst_port: Optional[int] = None,
|
|
||||||
) -> List[Region]:
|
|
||||||
entries = cls.search_forward(
|
|
||||||
scaleset_id=scaleset_id,
|
|
||||||
machine_id=machine_id,
|
|
||||||
proxy_id=proxy_id,
|
|
||||||
dst_port=dst_port,
|
|
||||||
)
|
|
||||||
regions = set()
|
|
||||||
for entry in entries:
|
|
||||||
regions.add(entry.region)
|
|
||||||
entry.delete()
|
|
||||||
return list(regions)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def search_forward(
|
|
||||||
cls,
|
|
||||||
*,
|
|
||||||
scaleset_id: Optional[UUID] = None,
|
|
||||||
region: Optional[Region] = None,
|
|
||||||
machine_id: Optional[UUID] = None,
|
|
||||||
proxy_id: Optional[UUID] = None,
|
|
||||||
dst_port: Optional[int] = None,
|
|
||||||
) -> List["ProxyForward"]:
|
|
||||||
|
|
||||||
query: QueryFilter = {}
|
|
||||||
if region is not None:
|
|
||||||
query["region"] = [region]
|
|
||||||
|
|
||||||
if scaleset_id is not None:
|
|
||||||
query["scaleset_id"] = [scaleset_id]
|
|
||||||
|
|
||||||
if machine_id is not None:
|
|
||||||
query["machine_id"] = [machine_id]
|
|
||||||
|
|
||||||
if proxy_id is not None:
|
|
||||||
query["proxy_id"] = [proxy_id]
|
|
||||||
|
|
||||||
if dst_port is not None:
|
|
||||||
query["dst_port"] = [dst_port]
|
|
||||||
|
|
||||||
return cls.search(query=query)
|
|
||||||
|
|
||||||
def to_forward(self) -> Forward:
|
|
||||||
return Forward(src_port=self.port, dst_ip=self.dst_ip, dst_port=self.dst_port)
|
|
@ -1,165 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from sys import getsizeof
|
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
from memoization import cached
|
|
||||||
from onefuzztypes.models import RegressionReport, Report
|
|
||||||
from onefuzztypes.primitives import Container
|
|
||||||
from pydantic import ValidationError
|
|
||||||
|
|
||||||
from .azure.containers import get_blob
|
|
||||||
from .azure.storage import StorageType
|
|
||||||
|
|
||||||
|
|
||||||
# This is fix for the following error:
|
|
||||||
# Exception while executing function:
|
|
||||||
# Functions.queue_file_changes Result: Failure
|
|
||||||
# Exception: AzureHttpError: Bad Request
|
|
||||||
# "The property value exceeds the maximum allowed size (64KB).
|
|
||||||
# If the property value is a string, it is UTF-16 encoded and
|
|
||||||
# the maximum number of characters should be 32K or less.
|
|
||||||
def fix_report_size(
|
|
||||||
content: str,
|
|
||||||
report: Report,
|
|
||||||
acceptable_report_length_kb: int = 24,
|
|
||||||
keep_num_entries: int = 10,
|
|
||||||
keep_string_len: int = 256,
|
|
||||||
) -> Report:
|
|
||||||
logging.info(f"report content length {getsizeof(content)}")
|
|
||||||
if getsizeof(content) > acceptable_report_length_kb * 1024:
|
|
||||||
msg = f"report data exceeds {acceptable_report_length_kb}K {getsizeof(content)}"
|
|
||||||
if len(report.call_stack) > keep_num_entries:
|
|
||||||
msg = msg + "; removing some of stack frames from the report"
|
|
||||||
report.call_stack = report.call_stack[0:keep_num_entries] + ["..."]
|
|
||||||
|
|
||||||
if report.asan_log and len(report.asan_log) > keep_string_len:
|
|
||||||
msg = msg + "; removing some of asan log entries from the report"
|
|
||||||
report.asan_log = report.asan_log[0:keep_string_len] + "..."
|
|
||||||
|
|
||||||
if report.minimized_stack and len(report.minimized_stack) > keep_num_entries:
|
|
||||||
msg = msg + "; removing some of minimized stack frames from the report"
|
|
||||||
report.minimized_stack = report.minimized_stack[0:keep_num_entries] + [
|
|
||||||
"..."
|
|
||||||
]
|
|
||||||
|
|
||||||
if (
|
|
||||||
report.minimized_stack_function_names
|
|
||||||
and len(report.minimized_stack_function_names) > keep_num_entries
|
|
||||||
):
|
|
||||||
msg = (
|
|
||||||
msg
|
|
||||||
+ "; removing some of minimized stack function names from the report"
|
|
||||||
)
|
|
||||||
report.minimized_stack_function_names = (
|
|
||||||
report.minimized_stack_function_names[0:keep_num_entries] + ["..."]
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
|
||||||
report.minimized_stack_function_lines
|
|
||||||
and len(report.minimized_stack_function_lines) > keep_num_entries
|
|
||||||
):
|
|
||||||
msg = (
|
|
||||||
msg
|
|
||||||
+ "; removing some of minimized stack function lines from the report"
|
|
||||||
)
|
|
||||||
report.minimized_stack_function_lines = (
|
|
||||||
report.minimized_stack_function_lines[0:keep_num_entries] + ["..."]
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.info(msg)
|
|
||||||
return report
|
|
||||||
|
|
||||||
|
|
||||||
def parse_report_or_regression(
|
|
||||||
content: Union[str, bytes],
|
|
||||||
file_path: Optional[str] = None,
|
|
||||||
expect_reports: bool = False,
|
|
||||||
) -> Optional[Union[Report, RegressionReport]]:
|
|
||||||
if isinstance(content, bytes):
|
|
||||||
try:
|
|
||||||
content = content.decode()
|
|
||||||
except UnicodeDecodeError as err:
|
|
||||||
if expect_reports:
|
|
||||||
logging.error(
|
|
||||||
f"unable to parse report ({file_path}): "
|
|
||||||
f"unicode decode of report failed - {err}"
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
data = json.loads(content)
|
|
||||||
except json.decoder.JSONDecodeError as err:
|
|
||||||
if expect_reports:
|
|
||||||
logging.error(
|
|
||||||
f"unable to parse report ({file_path}): json decoding failed - {err}"
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
regression_err = None
|
|
||||||
try:
|
|
||||||
regression_report = RegressionReport.parse_obj(data)
|
|
||||||
|
|
||||||
if (
|
|
||||||
regression_report.crash_test_result is not None
|
|
||||||
and regression_report.crash_test_result.crash_report is not None
|
|
||||||
):
|
|
||||||
regression_report.crash_test_result.crash_report = fix_report_size(
|
|
||||||
content, regression_report.crash_test_result.crash_report
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
|
||||||
regression_report.original_crash_test_result is not None
|
|
||||||
and regression_report.original_crash_test_result.crash_report is not None
|
|
||||||
):
|
|
||||||
regression_report.original_crash_test_result.crash_report = fix_report_size(
|
|
||||||
content, regression_report.original_crash_test_result.crash_report
|
|
||||||
)
|
|
||||||
return regression_report
|
|
||||||
except ValidationError as err:
|
|
||||||
regression_err = err
|
|
||||||
|
|
||||||
try:
|
|
||||||
report = Report.parse_obj(data)
|
|
||||||
return fix_report_size(content, report)
|
|
||||||
except ValidationError as err:
|
|
||||||
if expect_reports:
|
|
||||||
logging.error(
|
|
||||||
f"unable to parse report ({file_path}) as a report or regression. "
|
|
||||||
f"regression error: {regression_err} report error: {err}"
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
# cache the last 1000 reports
|
|
||||||
@cached(max_size=1000)
|
|
||||||
def get_report_or_regression(
|
|
||||||
container: Container, filename: str, *, expect_reports: bool = False
|
|
||||||
) -> Optional[Union[Report, RegressionReport]]:
|
|
||||||
file_path = "/".join([container, filename])
|
|
||||||
if not filename.endswith(".json"):
|
|
||||||
if expect_reports:
|
|
||||||
logging.error("get_report invalid extension: %s", file_path)
|
|
||||||
return None
|
|
||||||
|
|
||||||
blob = get_blob(container, filename, StorageType.corpus)
|
|
||||||
if blob is None:
|
|
||||||
if expect_reports:
|
|
||||||
logging.error("get_report invalid blob: %s", file_path)
|
|
||||||
return None
|
|
||||||
|
|
||||||
return parse_report_or_regression(
|
|
||||||
blob, file_path=file_path, expect_reports=expect_reports
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_report(container: Container, filename: str) -> Optional[Report]:
|
|
||||||
result = get_report_or_regression(container, filename)
|
|
||||||
if isinstance(result, Report):
|
|
||||||
return result
|
|
||||||
return None
|
|
@ -1,286 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import List, Optional, Tuple, Union
|
|
||||||
|
|
||||||
from azure.mgmt.compute.models import VirtualMachine
|
|
||||||
from onefuzztypes.enums import OS, ContainerType, ErrorCode, VmState
|
|
||||||
from onefuzztypes.models import Error
|
|
||||||
from onefuzztypes.models import Repro as BASE_REPRO
|
|
||||||
from onefuzztypes.models import ReproConfig, TaskVm, UserInfo
|
|
||||||
from onefuzztypes.primitives import Container
|
|
||||||
|
|
||||||
from .azure.auth import build_auth
|
|
||||||
from .azure.containers import save_blob
|
|
||||||
from .azure.creds import get_base_region
|
|
||||||
from .azure.ip import get_public_ip
|
|
||||||
from .azure.nsg import NSG
|
|
||||||
from .azure.storage import StorageType
|
|
||||||
from .azure.vm import VM
|
|
||||||
from .config import InstanceConfig
|
|
||||||
from .extension import repro_extensions
|
|
||||||
from .orm import ORMMixin, QueryFilter
|
|
||||||
from .reports import get_report
|
|
||||||
from .tasks.main import Task
|
|
||||||
|
|
||||||
DEFAULT_SKU = "Standard_DS1_v2"
|
|
||||||
|
|
||||||
|
|
||||||
class Repro(BASE_REPRO, ORMMixin):
|
|
||||||
def set_error(self, error: Error) -> None:
|
|
||||||
logging.error(
|
|
||||||
"repro failed: vm_id: %s task_id: %s: error: %s",
|
|
||||||
self.vm_id,
|
|
||||||
self.task_id,
|
|
||||||
error,
|
|
||||||
)
|
|
||||||
self.error = error
|
|
||||||
self.state = VmState.stopping
|
|
||||||
self.save()
|
|
||||||
|
|
||||||
def get_vm(self, config: InstanceConfig) -> VM:
|
|
||||||
tags = None
|
|
||||||
if config.vm_tags:
|
|
||||||
tags = config.vm_tags
|
|
||||||
|
|
||||||
task = Task.get_by_task_id(self.task_id)
|
|
||||||
if isinstance(task, Error):
|
|
||||||
raise Exception("previously existing task missing: %s" % self.task_id)
|
|
||||||
|
|
||||||
config = InstanceConfig.fetch()
|
|
||||||
default_os = {
|
|
||||||
OS.linux: config.default_linux_vm_image,
|
|
||||||
OS.windows: config.default_windows_vm_image,
|
|
||||||
}
|
|
||||||
vm_config = task.get_repro_vm_config()
|
|
||||||
if vm_config is None:
|
|
||||||
# if using a pool without any scalesets defined yet, use reasonable defaults
|
|
||||||
if task.os not in default_os:
|
|
||||||
raise NotImplementedError("unsupported OS for repro %s" % task.os)
|
|
||||||
|
|
||||||
vm_config = TaskVm(
|
|
||||||
region=get_base_region(), sku=DEFAULT_SKU, image=default_os[task.os]
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.auth is None:
|
|
||||||
raise Exception("missing auth")
|
|
||||||
|
|
||||||
return VM(
|
|
||||||
name=self.vm_id,
|
|
||||||
region=vm_config.region,
|
|
||||||
sku=vm_config.sku,
|
|
||||||
image=vm_config.image,
|
|
||||||
auth=self.auth,
|
|
||||||
tags=tags,
|
|
||||||
)
|
|
||||||
|
|
||||||
def init(self) -> None:
|
|
||||||
config = InstanceConfig.fetch()
|
|
||||||
vm = self.get_vm(config)
|
|
||||||
vm_data = vm.get()
|
|
||||||
if vm_data:
|
|
||||||
if vm_data.provisioning_state == "Failed":
|
|
||||||
self.set_failed(vm)
|
|
||||||
else:
|
|
||||||
script_result = self.build_repro_script()
|
|
||||||
if isinstance(script_result, Error):
|
|
||||||
self.set_error(script_result)
|
|
||||||
return
|
|
||||||
|
|
||||||
self.state = VmState.extensions_launch
|
|
||||||
else:
|
|
||||||
nsg = NSG(
|
|
||||||
name=vm.region,
|
|
||||||
region=vm.region,
|
|
||||||
)
|
|
||||||
result = nsg.create()
|
|
||||||
if isinstance(result, Error):
|
|
||||||
self.set_failed(result)
|
|
||||||
return
|
|
||||||
|
|
||||||
nsg_config = config.proxy_nsg_config
|
|
||||||
result = nsg.set_allowed_sources(nsg_config)
|
|
||||||
if isinstance(result, Error):
|
|
||||||
self.set_failed(result)
|
|
||||||
return
|
|
||||||
|
|
||||||
vm.nsg = nsg
|
|
||||||
result = vm.create()
|
|
||||||
if isinstance(result, Error):
|
|
||||||
self.set_error(result)
|
|
||||||
return
|
|
||||||
self.save()
|
|
||||||
|
|
||||||
def set_failed(self, vm_data: VirtualMachine) -> None:
|
|
||||||
errors = []
|
|
||||||
for status in vm_data.instance_view.statuses:
|
|
||||||
if status.level.name.lower() == "error":
|
|
||||||
errors.append(
|
|
||||||
"%s %s %s" % (status.code, status.display_status, status.message)
|
|
||||||
)
|
|
||||||
return self.set_error(Error(code=ErrorCode.VM_CREATE_FAILED, errors=errors))
|
|
||||||
|
|
||||||
def get_setup_container(self) -> Optional[Container]:
|
|
||||||
task = Task.get_by_task_id(self.task_id)
|
|
||||||
if isinstance(task, Task):
|
|
||||||
for container in task.config.containers:
|
|
||||||
if container.type == ContainerType.setup:
|
|
||||||
return container.name
|
|
||||||
return None
|
|
||||||
|
|
||||||
def extensions_launch(self) -> None:
|
|
||||||
config = InstanceConfig.fetch()
|
|
||||||
vm = self.get_vm(config)
|
|
||||||
vm_data = vm.get()
|
|
||||||
if not vm_data:
|
|
||||||
self.set_error(
|
|
||||||
Error(
|
|
||||||
code=ErrorCode.VM_CREATE_FAILED,
|
|
||||||
errors=["failed before launching extensions"],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
if vm_data.provisioning_state == "Failed":
|
|
||||||
self.set_failed(vm_data)
|
|
||||||
return
|
|
||||||
|
|
||||||
if not self.ip:
|
|
||||||
self.ip = get_public_ip(vm_data.network_profile.network_interfaces[0].id)
|
|
||||||
|
|
||||||
extensions = repro_extensions(
|
|
||||||
vm.region, self.os, self.vm_id, self.config, self.get_setup_container()
|
|
||||||
)
|
|
||||||
result = vm.add_extensions(extensions)
|
|
||||||
if isinstance(result, Error):
|
|
||||||
self.set_error(result)
|
|
||||||
return
|
|
||||||
elif result:
|
|
||||||
self.state = VmState.running
|
|
||||||
|
|
||||||
self.save()
|
|
||||||
|
|
||||||
def stopping(self) -> None:
|
|
||||||
config = InstanceConfig.fetch()
|
|
||||||
vm = self.get_vm(config)
|
|
||||||
if not vm.is_deleted():
|
|
||||||
logging.info("vm stopping: %s", self.vm_id)
|
|
||||||
vm.delete()
|
|
||||||
self.save()
|
|
||||||
else:
|
|
||||||
self.stopped()
|
|
||||||
|
|
||||||
def stopped(self) -> None:
|
|
||||||
logging.info("vm stopped: %s", self.vm_id)
|
|
||||||
self.delete()
|
|
||||||
|
|
||||||
def build_repro_script(self) -> Optional[Error]:
|
|
||||||
if self.auth is None:
|
|
||||||
return Error(code=ErrorCode.VM_CREATE_FAILED, errors=["missing auth"])
|
|
||||||
|
|
||||||
task = Task.get_by_task_id(self.task_id)
|
|
||||||
if isinstance(task, Error):
|
|
||||||
return task
|
|
||||||
|
|
||||||
report = get_report(self.config.container, self.config.path)
|
|
||||||
if report is None:
|
|
||||||
return Error(code=ErrorCode.VM_CREATE_FAILED, errors=["missing report"])
|
|
||||||
|
|
||||||
if report.input_blob is None:
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.VM_CREATE_FAILED,
|
|
||||||
errors=["unable to perform repro for crash reports without inputs"],
|
|
||||||
)
|
|
||||||
|
|
||||||
files = {}
|
|
||||||
|
|
||||||
if task.os == OS.windows:
|
|
||||||
ssh_path = "$env:ProgramData/ssh/administrators_authorized_keys"
|
|
||||||
cmds = [
|
|
||||||
'Set-Content -Path %s -Value "%s"' % (ssh_path, self.auth.public_key),
|
|
||||||
". C:\\onefuzz\\tools\\win64\\onefuzz.ps1",
|
|
||||||
"Set-SetSSHACL",
|
|
||||||
'while (1) { cdb -server tcp:port=1337 -c "g" setup\\%s %s }'
|
|
||||||
% (
|
|
||||||
task.config.task.target_exe,
|
|
||||||
report.input_blob.name,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
cmd = "\r\n".join(cmds)
|
|
||||||
files["repro.ps1"] = cmd
|
|
||||||
elif task.os == OS.linux:
|
|
||||||
gdb_fmt = (
|
|
||||||
"ASAN_OPTIONS='abort_on_error=1' gdbserver "
|
|
||||||
"%s /onefuzz/setup/%s /onefuzz/downloaded/%s"
|
|
||||||
)
|
|
||||||
cmd = "while :; do %s; done" % (
|
|
||||||
gdb_fmt
|
|
||||||
% (
|
|
||||||
"localhost:1337",
|
|
||||||
task.config.task.target_exe,
|
|
||||||
report.input_blob.name,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
files["repro.sh"] = cmd
|
|
||||||
|
|
||||||
cmd = "#!/bin/bash\n%s" % (
|
|
||||||
gdb_fmt % ("-", task.config.task.target_exe, report.input_blob.name)
|
|
||||||
)
|
|
||||||
files["repro-stdout.sh"] = cmd
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("invalid task os: %s" % task.os)
|
|
||||||
|
|
||||||
for filename in files:
|
|
||||||
save_blob(
|
|
||||||
Container("repro-scripts"),
|
|
||||||
"%s/%s" % (self.vm_id, filename),
|
|
||||||
files[filename],
|
|
||||||
StorageType.config,
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.info("saved repro script")
|
|
||||||
return None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def search_states(cls, *, states: Optional[List[VmState]] = None) -> List["Repro"]:
|
|
||||||
query: QueryFilter = {}
|
|
||||||
if states:
|
|
||||||
query["state"] = states
|
|
||||||
return cls.search(query=query)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create(
|
|
||||||
cls, config: ReproConfig, user_info: Optional[UserInfo]
|
|
||||||
) -> Union[Error, "Repro"]:
|
|
||||||
report = get_report(config.container, config.path)
|
|
||||||
if not report:
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.UNABLE_TO_FIND, errors=["unable to find report"]
|
|
||||||
)
|
|
||||||
|
|
||||||
task = Task.get_by_task_id(report.task_id)
|
|
||||||
if isinstance(task, Error):
|
|
||||||
return task
|
|
||||||
|
|
||||||
vm = cls(config=config, task_id=task.task_id, os=task.os, auth=build_auth())
|
|
||||||
if vm.end_time is None:
|
|
||||||
vm.end_time = datetime.utcnow() + timedelta(hours=config.duration)
|
|
||||||
|
|
||||||
vm.user_info = user_info
|
|
||||||
vm.save()
|
|
||||||
|
|
||||||
return vm
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def search_expired(cls) -> List["Repro"]:
|
|
||||||
# unlike jobs/tasks, the entry is deleted from the backing table upon stop
|
|
||||||
time_filter = "end_time lt datetime'%s'" % datetime.utcnow().isoformat()
|
|
||||||
return cls.search(raw_unchecked_filter=time_filter)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def key_fields(cls) -> Tuple[str, Optional[str]]:
|
|
||||||
return ("vm_id", None)
|
|
@ -1,118 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from typing import TYPE_CHECKING, Sequence, Type, TypeVar, Union
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from azure.functions import HttpRequest, HttpResponse
|
|
||||||
from onefuzztypes.enums import ErrorCode
|
|
||||||
from onefuzztypes.models import Error
|
|
||||||
from onefuzztypes.responses import BaseResponse
|
|
||||||
from pydantic import BaseModel # noqa: F401
|
|
||||||
from pydantic import ValidationError
|
|
||||||
|
|
||||||
from .orm import ModelMixin
|
|
||||||
|
|
||||||
# We don't actually use these types at runtime at this time. Rather,
|
|
||||||
# these are used in a bound TypeVar. MyPy suggests to only import these
|
|
||||||
# types during type checking.
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from onefuzztypes.requests import BaseRequest # noqa: F401
|
|
||||||
|
|
||||||
|
|
||||||
def ok(
|
|
||||||
data: Union[BaseResponse, Sequence[BaseResponse], ModelMixin, Sequence[ModelMixin]]
|
|
||||||
) -> HttpResponse:
|
|
||||||
if isinstance(data, BaseResponse):
|
|
||||||
return HttpResponse(data.json(exclude_none=True), mimetype="application/json")
|
|
||||||
|
|
||||||
if isinstance(data, list) and len(data) > 0 and isinstance(data[0], BaseResponse):
|
|
||||||
decoded = [json.loads(x.json(exclude_none=True)) for x in data]
|
|
||||||
return HttpResponse(json.dumps(decoded), mimetype="application/json")
|
|
||||||
|
|
||||||
if isinstance(data, ModelMixin):
|
|
||||||
return HttpResponse(
|
|
||||||
data.json(exclude_none=True, exclude=data.export_exclude()),
|
|
||||||
mimetype="application/json",
|
|
||||||
)
|
|
||||||
|
|
||||||
decoded = [
|
|
||||||
x.raw(exclude_none=True, exclude=x.export_exclude())
|
|
||||||
if isinstance(x, ModelMixin)
|
|
||||||
else x
|
|
||||||
for x in data
|
|
||||||
]
|
|
||||||
return HttpResponse(
|
|
||||||
json.dumps(decoded),
|
|
||||||
mimetype="application/json",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def not_ok(
|
|
||||||
error: Error, *, status_code: int = 400, context: Union[str, UUID]
|
|
||||||
) -> HttpResponse:
|
|
||||||
if 400 <= status_code <= 599:
|
|
||||||
logging.error("request error - %s: %s", str(context), error.json())
|
|
||||||
|
|
||||||
return HttpResponse(
|
|
||||||
error.json(), status_code=status_code, mimetype="application/json"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise Exception(
|
|
||||||
"status code %s is not int the expected range [400; 599]" % status_code
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def redirect(url: str) -> HttpResponse:
|
|
||||||
return HttpResponse(status_code=302, headers={"Location": url})
|
|
||||||
|
|
||||||
|
|
||||||
def convert_error(err: ValidationError) -> Error:
|
|
||||||
errors = []
|
|
||||||
for error in err.errors():
|
|
||||||
if isinstance(error["loc"], tuple):
|
|
||||||
name = ".".join([str(x) for x in error["loc"]])
|
|
||||||
else:
|
|
||||||
name = str(error["loc"])
|
|
||||||
errors.append("%s: %s" % (name, error["msg"]))
|
|
||||||
return Error(code=ErrorCode.INVALID_REQUEST, errors=errors)
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: loosen restrictions here during dev. We should be specific
|
|
||||||
# about only parsing things that are of a "Request" type, but there are
|
|
||||||
# a handful of types that need work in order to enforce that.
|
|
||||||
#
|
|
||||||
# These can be easily found by swapping the following comment and running
|
|
||||||
# mypy.
|
|
||||||
#
|
|
||||||
# A = TypeVar("A", bound="BaseRequest")
|
|
||||||
A = TypeVar("A", bound="BaseModel")
|
|
||||||
|
|
||||||
|
|
||||||
def parse_request(cls: Type[A], req: HttpRequest) -> Union[A, Error]:
|
|
||||||
try:
|
|
||||||
return cls.parse_obj(req.get_json())
|
|
||||||
except ValidationError as err:
|
|
||||||
return convert_error(err)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_uri(cls: Type[A], req: HttpRequest) -> Union[A, Error]:
|
|
||||||
data = {}
|
|
||||||
for key in req.params:
|
|
||||||
data[key] = req.params[key]
|
|
||||||
|
|
||||||
try:
|
|
||||||
return cls.parse_obj(data)
|
|
||||||
except ValidationError as err:
|
|
||||||
return convert_error(err)
|
|
||||||
|
|
||||||
|
|
||||||
class RequestException(Exception):
|
|
||||||
def __init__(self, error: Error):
|
|
||||||
self.error = error
|
|
||||||
message = "error %s" % error
|
|
||||||
super().__init__(message)
|
|
@ -1,109 +0,0 @@
|
|||||||
from typing import Dict, List, Optional
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from onefuzztypes.models import ApiAccessRule
|
|
||||||
|
|
||||||
|
|
||||||
class RuleConflictError(Exception):
|
|
||||||
def __init__(self, message: str) -> None:
|
|
||||||
super(RuleConflictError, self).__init__(message)
|
|
||||||
self.message = message
|
|
||||||
|
|
||||||
|
|
||||||
class RequestAccess:
|
|
||||||
"""
|
|
||||||
Stores the rules associated with a the request paths
|
|
||||||
"""
|
|
||||||
|
|
||||||
class Rules:
|
|
||||||
allowed_groups_ids: List[UUID]
|
|
||||||
|
|
||||||
def __init__(self, allowed_groups_ids: List[UUID] = []) -> None:
|
|
||||||
self.allowed_groups_ids = allowed_groups_ids
|
|
||||||
|
|
||||||
class Node:
|
|
||||||
# http method -> rules
|
|
||||||
rules: Dict[str, "RequestAccess.Rules"]
|
|
||||||
# path -> node
|
|
||||||
children: Dict[str, "RequestAccess.Node"]
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.rules = {}
|
|
||||||
self.children = {}
|
|
||||||
pass
|
|
||||||
|
|
||||||
root: Node
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.root = RequestAccess.Node()
|
|
||||||
|
|
||||||
def __add_url__(self, methods: List[str], path: str, rules: Rules) -> None:
|
|
||||||
methods = list(map(lambda m: m.upper(), methods))
|
|
||||||
|
|
||||||
segments = [s for s in path.split("/") if s != ""]
|
|
||||||
if len(segments) == 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
current_node = self.root
|
|
||||||
current_segment_index = 0
|
|
||||||
|
|
||||||
while current_segment_index < len(segments):
|
|
||||||
current_segment = segments[current_segment_index]
|
|
||||||
if current_segment in current_node.children:
|
|
||||||
current_node = current_node.children[current_segment]
|
|
||||||
current_segment_index = current_segment_index + 1
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
# we found a node matching this exact path
|
|
||||||
# This means that there is an existing rule causing a conflict
|
|
||||||
if current_segment_index == len(segments):
|
|
||||||
for method in methods:
|
|
||||||
if method in current_node.rules:
|
|
||||||
raise RuleConflictError(f"Conflicting rules on {method} {path}")
|
|
||||||
|
|
||||||
while current_segment_index < len(segments):
|
|
||||||
current_segment = segments[current_segment_index]
|
|
||||||
current_node.children[current_segment] = RequestAccess.Node()
|
|
||||||
current_node = current_node.children[current_segment]
|
|
||||||
current_segment_index = current_segment_index + 1
|
|
||||||
|
|
||||||
for method in methods:
|
|
||||||
current_node.rules[method] = rules
|
|
||||||
|
|
||||||
def get_matching_rules(self, method: str, path: str) -> Optional[Rules]:
|
|
||||||
method = method.upper()
|
|
||||||
segments = [s for s in path.split("/") if s != ""]
|
|
||||||
current_node = self.root
|
|
||||||
current_rule = None
|
|
||||||
|
|
||||||
if method in current_node.rules:
|
|
||||||
current_rule = current_node.rules[method]
|
|
||||||
|
|
||||||
current_segment_index = 0
|
|
||||||
|
|
||||||
while current_segment_index < len(segments):
|
|
||||||
current_segment = segments[current_segment_index]
|
|
||||||
if current_segment in current_node.children:
|
|
||||||
current_node = current_node.children[current_segment]
|
|
||||||
elif "*" in current_node.children:
|
|
||||||
current_node = current_node.children["*"]
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
if method in current_node.rules:
|
|
||||||
current_rule = current_node.rules[method]
|
|
||||||
current_segment_index = current_segment_index + 1
|
|
||||||
return current_rule
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def build(cls, rules: Dict[str, ApiAccessRule]) -> "RequestAccess":
|
|
||||||
request_access = RequestAccess()
|
|
||||||
for endpoint in rules:
|
|
||||||
rule = rules[endpoint]
|
|
||||||
request_access.__add_url__(
|
|
||||||
rule.methods,
|
|
||||||
endpoint,
|
|
||||||
RequestAccess.Rules(allowed_groups_ids=rule.allowed_groups),
|
|
||||||
)
|
|
||||||
|
|
||||||
return request_access
|
|
@ -1,87 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
|
|
||||||
import os
|
|
||||||
from typing import Tuple, Type, TypeVar
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
from azure.keyvault.secrets import KeyVaultSecret
|
|
||||||
from onefuzztypes.models import SecretAddress, SecretData
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from .azure.creds import get_keyvault_client
|
|
||||||
|
|
||||||
A = TypeVar("A", bound=BaseModel)
|
|
||||||
|
|
||||||
|
|
||||||
def save_to_keyvault(secret_data: SecretData) -> SecretData:
|
|
||||||
if isinstance(secret_data.secret, SecretAddress):
|
|
||||||
return secret_data
|
|
||||||
|
|
||||||
secret_name = str(uuid4())
|
|
||||||
if isinstance(secret_data.secret, str):
|
|
||||||
secret_value = secret_data.secret
|
|
||||||
elif isinstance(secret_data.secret, BaseModel):
|
|
||||||
secret_value = secret_data.secret.json()
|
|
||||||
else:
|
|
||||||
raise Exception("invalid secret data")
|
|
||||||
|
|
||||||
kv = store_in_keyvault(get_keyvault_address(), secret_name, secret_value)
|
|
||||||
secret_data.secret = SecretAddress(url=kv.id)
|
|
||||||
return secret_data
|
|
||||||
|
|
||||||
|
|
||||||
def get_secret_string_value(self: SecretData[str]) -> str:
|
|
||||||
if isinstance(self.secret, SecretAddress):
|
|
||||||
secret = get_secret(self.secret.url)
|
|
||||||
return secret.value
|
|
||||||
else:
|
|
||||||
return self.secret
|
|
||||||
|
|
||||||
|
|
||||||
def get_keyvault_address() -> str:
|
|
||||||
# https://docs.microsoft.com/en-us/azure/key-vault/general/about-keys-secrets-certificates#vault-name-and-object-name
|
|
||||||
keyvault_name = os.environ["ONEFUZZ_KEYVAULT"]
|
|
||||||
return f"https://{keyvault_name}.vault.azure.net"
|
|
||||||
|
|
||||||
|
|
||||||
def store_in_keyvault(
|
|
||||||
keyvault_url: str, secret_name: str, secret_value: str
|
|
||||||
) -> KeyVaultSecret:
|
|
||||||
keyvault_client = get_keyvault_client(keyvault_url)
|
|
||||||
kvs: KeyVaultSecret = keyvault_client.set_secret(secret_name, secret_value)
|
|
||||||
return kvs
|
|
||||||
|
|
||||||
|
|
||||||
def parse_secret_url(secret_url: str) -> Tuple[str, str]:
|
|
||||||
# format: https://{vault-name}.vault.azure.net/secrets/{secret-name}/{version}
|
|
||||||
u = urlparse(secret_url)
|
|
||||||
vault_url = f"{u.scheme}://{u.netloc}"
|
|
||||||
secret_name = u.path.split("/")[2]
|
|
||||||
return (vault_url, secret_name)
|
|
||||||
|
|
||||||
|
|
||||||
def get_secret(secret_url: str) -> KeyVaultSecret:
|
|
||||||
(vault_url, secret_name) = parse_secret_url(secret_url)
|
|
||||||
keyvault_client = get_keyvault_client(vault_url)
|
|
||||||
return keyvault_client.get_secret(secret_name)
|
|
||||||
|
|
||||||
|
|
||||||
def get_secret_obj(secret_url: str, model: Type[A]) -> A:
|
|
||||||
secret = get_secret(secret_url)
|
|
||||||
return model.parse_raw(secret.value)
|
|
||||||
|
|
||||||
|
|
||||||
def delete_secret(secret_url: str) -> None:
|
|
||||||
(vault_url, secret_name) = parse_secret_url(secret_url)
|
|
||||||
keyvault_client = get_keyvault_client(vault_url)
|
|
||||||
keyvault_client.begin_delete_secret(secret_name)
|
|
||||||
|
|
||||||
|
|
||||||
def delete_remote_secret_data(data: SecretData) -> None:
|
|
||||||
if isinstance(data.secret, SecretAddress):
|
|
||||||
delete_secret(data.secret.url)
|
|
@ -1,52 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import datetime
|
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from onefuzztypes.models import TaskEvent as BASE_TASK_EVENT
|
|
||||||
from onefuzztypes.models import TaskEventSummary, WorkerEvent
|
|
||||||
|
|
||||||
from .orm import ORMMixin
|
|
||||||
|
|
||||||
|
|
||||||
class TaskEvent(BASE_TASK_EVENT, ORMMixin):
|
|
||||||
@classmethod
|
|
||||||
def get_summary(cls, task_id: UUID) -> List[TaskEventSummary]:
|
|
||||||
events = cls.search(query={"task_id": [task_id]})
|
|
||||||
# handle None case of Optional[e.timestamp], which shouldn't happen
|
|
||||||
events.sort(key=lambda e: e.timestamp or datetime.datetime.max)
|
|
||||||
|
|
||||||
return [
|
|
||||||
TaskEventSummary(
|
|
||||||
timestamp=e.timestamp,
|
|
||||||
event_data=get_event_data(e.event_data),
|
|
||||||
event_type=get_event_type(e.event_data),
|
|
||||||
)
|
|
||||||
for e in events
|
|
||||||
]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def key_fields(cls) -> Tuple[str, Optional[str]]:
|
|
||||||
return ("task_id", None)
|
|
||||||
|
|
||||||
|
|
||||||
def get_event_data(event: WorkerEvent) -> str:
|
|
||||||
if event.done:
|
|
||||||
return "exit status: %s" % event.done.exit_status
|
|
||||||
elif event.running:
|
|
||||||
return ""
|
|
||||||
else:
|
|
||||||
return "Unrecognized event: %s" % event
|
|
||||||
|
|
||||||
|
|
||||||
def get_event_type(event: WorkerEvent) -> str:
|
|
||||||
if event.done:
|
|
||||||
return type(event.done).__name__
|
|
||||||
elif event.running:
|
|
||||||
return type(event.running).__name__
|
|
||||||
else:
|
|
||||||
return "Unrecognized event: %s" % event
|
|
@ -1,466 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import pathlib
|
|
||||||
from typing import Dict, List, Optional
|
|
||||||
|
|
||||||
from onefuzztypes.enums import Compare, ContainerPermission, ContainerType, TaskFeature
|
|
||||||
from onefuzztypes.models import Job, Task, TaskConfig, TaskDefinition, TaskUnitConfig
|
|
||||||
from onefuzztypes.primitives import Container
|
|
||||||
|
|
||||||
from ..azure.containers import (
|
|
||||||
add_container_sas_url,
|
|
||||||
blob_exists,
|
|
||||||
container_exists,
|
|
||||||
get_container_sas_url,
|
|
||||||
)
|
|
||||||
from ..azure.creds import get_instance_id
|
|
||||||
from ..azure.queue import get_queue_sas
|
|
||||||
from ..azure.storage import StorageType
|
|
||||||
from ..workers.pools import Pool
|
|
||||||
from .defs import TASK_DEFINITIONS
|
|
||||||
|
|
||||||
LOGGER = logging.getLogger("onefuzz")
|
|
||||||
|
|
||||||
|
|
||||||
def get_input_container_queues(config: TaskConfig) -> Optional[List[str]]: # tasks.Task
|
|
||||||
|
|
||||||
if config.task.type not in TASK_DEFINITIONS:
|
|
||||||
raise TaskConfigError("unsupported task type: %s" % config.task.type.name)
|
|
||||||
|
|
||||||
container_type = TASK_DEFINITIONS[config.task.type].monitor_queue
|
|
||||||
if container_type:
|
|
||||||
return [x.name for x in config.containers if x.type == container_type]
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def check_val(compare: Compare, expected: int, actual: int) -> bool:
|
|
||||||
if compare == Compare.Equal:
|
|
||||||
return expected == actual
|
|
||||||
|
|
||||||
if compare == Compare.AtLeast:
|
|
||||||
return expected <= actual
|
|
||||||
|
|
||||||
if compare == Compare.AtMost:
|
|
||||||
return expected >= actual
|
|
||||||
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
def check_container(
|
|
||||||
compare: Compare,
|
|
||||||
expected: int,
|
|
||||||
container_type: ContainerType,
|
|
||||||
containers: Dict[ContainerType, List[Container]],
|
|
||||||
) -> None:
|
|
||||||
actual = len(containers.get(container_type, []))
|
|
||||||
if not check_val(compare, expected, actual):
|
|
||||||
raise TaskConfigError(
|
|
||||||
"container type %s: expected %s %d, got %d"
|
|
||||||
% (container_type.name, compare.name, expected, actual)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def check_containers(definition: TaskDefinition, config: TaskConfig) -> None:
|
|
||||||
checked = set()
|
|
||||||
|
|
||||||
containers: Dict[ContainerType, List[Container]] = {}
|
|
||||||
for container in config.containers:
|
|
||||||
if container.name not in checked:
|
|
||||||
if not container_exists(container.name, StorageType.corpus):
|
|
||||||
raise TaskConfigError("missing container: %s" % container.name)
|
|
||||||
checked.add(container.name)
|
|
||||||
|
|
||||||
if container.type not in containers:
|
|
||||||
containers[container.type] = []
|
|
||||||
containers[container.type].append(container.name)
|
|
||||||
|
|
||||||
for container_def in definition.containers:
|
|
||||||
check_container(
|
|
||||||
container_def.compare, container_def.value, container_def.type, containers
|
|
||||||
)
|
|
||||||
|
|
||||||
for container_type in containers:
|
|
||||||
if container_type not in [x.type for x in definition.containers]:
|
|
||||||
raise TaskConfigError(
|
|
||||||
"unsupported container type for this task: %s" % container_type.name
|
|
||||||
)
|
|
||||||
|
|
||||||
if definition.monitor_queue:
|
|
||||||
if definition.monitor_queue not in [x.type for x in definition.containers]:
|
|
||||||
raise TaskConfigError(
|
|
||||||
"unable to monitor container type as it is not used by this task: %s"
|
|
||||||
% definition.monitor_queue.name
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def check_target_exe(config: TaskConfig, definition: TaskDefinition) -> None:
|
|
||||||
if config.task.target_exe is None:
|
|
||||||
if TaskFeature.target_exe in definition.features:
|
|
||||||
raise TaskConfigError("missing target_exe")
|
|
||||||
|
|
||||||
if TaskFeature.target_exe_optional in definition.features:
|
|
||||||
return
|
|
||||||
|
|
||||||
return
|
|
||||||
|
|
||||||
# User-submitted paths must be relative to the setup directory that contains them.
|
|
||||||
# They also must be normalized, and exclude special filesystem path elements.
|
|
||||||
#
|
|
||||||
# For example, accessing the blob store path "./foo" generates an exception, but
|
|
||||||
# "foo" and "foo/bar" do not.
|
|
||||||
if not is_valid_blob_name(config.task.target_exe):
|
|
||||||
raise TaskConfigError("target_exe must be a canonicalized relative path")
|
|
||||||
|
|
||||||
container = [x for x in config.containers if x.type == ContainerType.setup][0]
|
|
||||||
if not blob_exists(container.name, config.task.target_exe, StorageType.corpus):
|
|
||||||
err = "target_exe `%s` does not exist in the setup container `%s`" % (
|
|
||||||
config.task.target_exe,
|
|
||||||
container.name,
|
|
||||||
)
|
|
||||||
LOGGER.warning(err)
|
|
||||||
|
|
||||||
|
|
||||||
# Azure Blob Storage uses a flat scheme, and has no true directory hierarchy. Forward
|
|
||||||
# slashes are used to delimit a _virtual_ directory structure.
|
|
||||||
def is_valid_blob_name(blob_name: str) -> bool:
|
|
||||||
# https://docs.microsoft.com/en-us/rest/api/storageservices/naming-and-referencing-containers--blobs--and-metadata#blob-names
|
|
||||||
MIN_LENGTH = 1
|
|
||||||
MAX_LENGTH = 1024 # Inclusive
|
|
||||||
MAX_PATH_SEGMENTS = 254
|
|
||||||
|
|
||||||
length = len(blob_name)
|
|
||||||
|
|
||||||
# No leading/trailing whitespace.
|
|
||||||
if blob_name != blob_name.strip():
|
|
||||||
return False
|
|
||||||
|
|
||||||
if length < MIN_LENGTH:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if length > MAX_LENGTH:
|
|
||||||
return False
|
|
||||||
|
|
||||||
path = pathlib.PurePosixPath(blob_name)
|
|
||||||
|
|
||||||
if len(path.parts) > MAX_PATH_SEGMENTS:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# No path segment should end with a dot (`.`).
|
|
||||||
for part in path.parts:
|
|
||||||
if part.endswith("."):
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Reject absolute paths to avoid confusion.
|
|
||||||
if path.is_absolute():
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Reject paths with special relative filesystem entries.
|
|
||||||
if "." in path.parts:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if ".." in path.parts:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Will not have a leading `.`, even if `blob_name` does.
|
|
||||||
normalized = path.as_posix()
|
|
||||||
|
|
||||||
return blob_name == normalized
|
|
||||||
|
|
||||||
|
|
||||||
def target_uses_input(config: TaskConfig) -> bool:
|
|
||||||
if config.task.target_options is not None:
|
|
||||||
for option in config.task.target_options:
|
|
||||||
if "{input}" in option:
|
|
||||||
return True
|
|
||||||
if config.task.target_env is not None:
|
|
||||||
for value in config.task.target_env.values():
|
|
||||||
if "{input}" in value:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def check_config(config: TaskConfig) -> None:
|
|
||||||
if config.task.type not in TASK_DEFINITIONS:
|
|
||||||
raise TaskConfigError("unsupported task type: %s" % config.task.type.name)
|
|
||||||
|
|
||||||
if config.vm is not None and config.pool is not None:
|
|
||||||
raise TaskConfigError("either the vm or pool must be specified, but not both")
|
|
||||||
|
|
||||||
definition = TASK_DEFINITIONS[config.task.type]
|
|
||||||
|
|
||||||
check_containers(definition, config)
|
|
||||||
|
|
||||||
if (
|
|
||||||
TaskFeature.supervisor_exe in definition.features
|
|
||||||
and not config.task.supervisor_exe
|
|
||||||
):
|
|
||||||
err = "missing supervisor_exe"
|
|
||||||
LOGGER.error(err)
|
|
||||||
raise TaskConfigError("missing supervisor_exe")
|
|
||||||
|
|
||||||
if (
|
|
||||||
TaskFeature.target_must_use_input in definition.features
|
|
||||||
and not target_uses_input(config)
|
|
||||||
):
|
|
||||||
raise TaskConfigError("{input} must be used in target_env or target_options")
|
|
||||||
|
|
||||||
if config.vm:
|
|
||||||
err = "specifying task config vm is no longer supported"
|
|
||||||
raise TaskConfigError(err)
|
|
||||||
|
|
||||||
if not config.pool:
|
|
||||||
raise TaskConfigError("pool must be specified")
|
|
||||||
|
|
||||||
if not check_val(definition.vm.compare, definition.vm.value, config.pool.count):
|
|
||||||
err = "invalid vm count: expected %s %d, got %s" % (
|
|
||||||
definition.vm.compare,
|
|
||||||
definition.vm.value,
|
|
||||||
config.pool.count,
|
|
||||||
)
|
|
||||||
LOGGER.error(err)
|
|
||||||
raise TaskConfigError(err)
|
|
||||||
|
|
||||||
pool = Pool.get_by_name(config.pool.pool_name)
|
|
||||||
if not isinstance(pool, Pool):
|
|
||||||
raise TaskConfigError(f"invalid pool: {config.pool.pool_name}")
|
|
||||||
|
|
||||||
check_target_exe(config, definition)
|
|
||||||
|
|
||||||
if TaskFeature.generator_exe in definition.features:
|
|
||||||
container = [x for x in config.containers if x.type == ContainerType.tools][0]
|
|
||||||
if not config.task.generator_exe:
|
|
||||||
raise TaskConfigError("generator_exe is not defined")
|
|
||||||
|
|
||||||
tools_paths = ["{tools_dir}/", "{tools_dir}\\"]
|
|
||||||
for tool_path in tools_paths:
|
|
||||||
if config.task.generator_exe.startswith(tool_path):
|
|
||||||
generator = config.task.generator_exe.replace(tool_path, "")
|
|
||||||
if not blob_exists(container.name, generator, StorageType.corpus):
|
|
||||||
err = (
|
|
||||||
"generator_exe `%s` does not exist in the tools container `%s`"
|
|
||||||
% (
|
|
||||||
config.task.generator_exe,
|
|
||||||
container.name,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
LOGGER.error(err)
|
|
||||||
raise TaskConfigError(err)
|
|
||||||
|
|
||||||
if TaskFeature.stats_file in definition.features:
|
|
||||||
if config.task.stats_file is not None and config.task.stats_format is None:
|
|
||||||
err = "using a stats_file requires a stats_format"
|
|
||||||
LOGGER.error(err)
|
|
||||||
raise TaskConfigError(err)
|
|
||||||
|
|
||||||
|
|
||||||
def build_task_config(job: Job, task: Task) -> TaskUnitConfig:
|
|
||||||
job_id = job.job_id
|
|
||||||
task_id = task.task_id
|
|
||||||
task_config = task.config
|
|
||||||
|
|
||||||
if task_config.task.type not in TASK_DEFINITIONS:
|
|
||||||
raise TaskConfigError("unsupported task type: %s" % task_config.task.type.name)
|
|
||||||
|
|
||||||
definition = TASK_DEFINITIONS[task_config.task.type]
|
|
||||||
|
|
||||||
config = TaskUnitConfig(
|
|
||||||
job_id=job_id,
|
|
||||||
task_id=task_id,
|
|
||||||
task_type=task_config.task.type,
|
|
||||||
instance_telemetry_key=os.environ.get("APPINSIGHTS_INSTRUMENTATIONKEY"),
|
|
||||||
microsoft_telemetry_key=os.environ.get("ONEFUZZ_TELEMETRY"),
|
|
||||||
heartbeat_queue=get_queue_sas(
|
|
||||||
"task-heartbeat",
|
|
||||||
StorageType.config,
|
|
||||||
add=True,
|
|
||||||
),
|
|
||||||
instance_id=get_instance_id(),
|
|
||||||
)
|
|
||||||
|
|
||||||
if job.config.logs:
|
|
||||||
config.logs = add_container_sas_url(job.config.logs)
|
|
||||||
else:
|
|
||||||
LOGGER.warning("Missing log container: job_id %s, task_id %s", job_id, task_id)
|
|
||||||
|
|
||||||
if definition.monitor_queue:
|
|
||||||
config.input_queue = get_queue_sas(
|
|
||||||
task_id,
|
|
||||||
StorageType.corpus,
|
|
||||||
add=True,
|
|
||||||
read=True,
|
|
||||||
update=True,
|
|
||||||
process=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
for container_def in definition.containers:
|
|
||||||
if container_def.type == ContainerType.setup:
|
|
||||||
continue
|
|
||||||
|
|
||||||
containers = []
|
|
||||||
for (i, container) in enumerate(task_config.containers):
|
|
||||||
if container.type != container_def.type:
|
|
||||||
continue
|
|
||||||
|
|
||||||
containers.append(
|
|
||||||
{
|
|
||||||
"path": "_".join(["task", container_def.type.name, str(i)]),
|
|
||||||
"url": get_container_sas_url(
|
|
||||||
container.name,
|
|
||||||
StorageType.corpus,
|
|
||||||
read=ContainerPermission.Read in container_def.permissions,
|
|
||||||
write=ContainerPermission.Write in container_def.permissions,
|
|
||||||
delete=ContainerPermission.Delete in container_def.permissions,
|
|
||||||
list_=ContainerPermission.List in container_def.permissions,
|
|
||||||
),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
if not containers:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if (
|
|
||||||
container_def.compare in [Compare.Equal, Compare.AtMost]
|
|
||||||
and container_def.value == 1
|
|
||||||
):
|
|
||||||
setattr(config, container_def.type.name, containers[0])
|
|
||||||
else:
|
|
||||||
setattr(config, container_def.type.name, containers)
|
|
||||||
|
|
||||||
EMPTY_DICT: Dict[str, str] = {}
|
|
||||||
EMPTY_LIST: List[str] = []
|
|
||||||
|
|
||||||
if TaskFeature.supervisor_exe in definition.features:
|
|
||||||
config.supervisor_exe = task_config.task.supervisor_exe
|
|
||||||
|
|
||||||
if TaskFeature.supervisor_env in definition.features:
|
|
||||||
config.supervisor_env = task_config.task.supervisor_env or EMPTY_DICT
|
|
||||||
|
|
||||||
if TaskFeature.supervisor_options in definition.features:
|
|
||||||
config.supervisor_options = task_config.task.supervisor_options or EMPTY_LIST
|
|
||||||
|
|
||||||
if TaskFeature.supervisor_input_marker in definition.features:
|
|
||||||
config.supervisor_input_marker = task_config.task.supervisor_input_marker
|
|
||||||
|
|
||||||
if TaskFeature.target_exe in definition.features:
|
|
||||||
config.target_exe = task_config.task.target_exe
|
|
||||||
|
|
||||||
if (
|
|
||||||
TaskFeature.target_exe_optional in definition.features
|
|
||||||
and task_config.task.target_exe
|
|
||||||
):
|
|
||||||
config.target_exe = task_config.task.target_exe
|
|
||||||
|
|
||||||
if TaskFeature.target_env in definition.features:
|
|
||||||
config.target_env = task_config.task.target_env or EMPTY_DICT
|
|
||||||
|
|
||||||
if TaskFeature.target_options in definition.features:
|
|
||||||
config.target_options = task_config.task.target_options or EMPTY_LIST
|
|
||||||
|
|
||||||
if TaskFeature.target_options_merge in definition.features:
|
|
||||||
config.target_options_merge = task_config.task.target_options_merge or False
|
|
||||||
|
|
||||||
if TaskFeature.target_workers in definition.features:
|
|
||||||
config.target_workers = task_config.task.target_workers
|
|
||||||
|
|
||||||
if TaskFeature.rename_output in definition.features:
|
|
||||||
config.rename_output = task_config.task.rename_output or False
|
|
||||||
|
|
||||||
if TaskFeature.generator_exe in definition.features:
|
|
||||||
config.generator_exe = task_config.task.generator_exe
|
|
||||||
|
|
||||||
if TaskFeature.generator_env in definition.features:
|
|
||||||
config.generator_env = task_config.task.generator_env or EMPTY_DICT
|
|
||||||
|
|
||||||
if TaskFeature.generator_options in definition.features:
|
|
||||||
config.generator_options = task_config.task.generator_options or EMPTY_LIST
|
|
||||||
|
|
||||||
if (
|
|
||||||
TaskFeature.wait_for_files in definition.features
|
|
||||||
and task_config.task.wait_for_files
|
|
||||||
):
|
|
||||||
config.wait_for_files = task_config.task.wait_for_files.name
|
|
||||||
|
|
||||||
if TaskFeature.analyzer_exe in definition.features:
|
|
||||||
config.analyzer_exe = task_config.task.analyzer_exe
|
|
||||||
|
|
||||||
if TaskFeature.analyzer_options in definition.features:
|
|
||||||
config.analyzer_options = task_config.task.analyzer_options or EMPTY_LIST
|
|
||||||
|
|
||||||
if TaskFeature.analyzer_env in definition.features:
|
|
||||||
config.analyzer_env = task_config.task.analyzer_env or EMPTY_DICT
|
|
||||||
|
|
||||||
if TaskFeature.stats_file in definition.features:
|
|
||||||
config.stats_file = task_config.task.stats_file
|
|
||||||
config.stats_format = task_config.task.stats_format
|
|
||||||
|
|
||||||
if TaskFeature.target_timeout in definition.features:
|
|
||||||
config.target_timeout = task_config.task.target_timeout
|
|
||||||
|
|
||||||
if TaskFeature.check_asan_log in definition.features:
|
|
||||||
config.check_asan_log = task_config.task.check_asan_log
|
|
||||||
|
|
||||||
if TaskFeature.check_debugger in definition.features:
|
|
||||||
config.check_debugger = task_config.task.check_debugger
|
|
||||||
|
|
||||||
if TaskFeature.check_retry_count in definition.features:
|
|
||||||
config.check_retry_count = task_config.task.check_retry_count or 0
|
|
||||||
|
|
||||||
if TaskFeature.ensemble_sync_delay in definition.features:
|
|
||||||
config.ensemble_sync_delay = task_config.task.ensemble_sync_delay
|
|
||||||
|
|
||||||
if TaskFeature.check_fuzzer_help in definition.features:
|
|
||||||
config.check_fuzzer_help = (
|
|
||||||
task_config.task.check_fuzzer_help
|
|
||||||
if task_config.task.check_fuzzer_help is not None
|
|
||||||
else True
|
|
||||||
)
|
|
||||||
|
|
||||||
if TaskFeature.report_list in definition.features:
|
|
||||||
config.report_list = task_config.task.report_list
|
|
||||||
|
|
||||||
if TaskFeature.minimized_stack_depth in definition.features:
|
|
||||||
config.minimized_stack_depth = task_config.task.minimized_stack_depth
|
|
||||||
|
|
||||||
if TaskFeature.expect_crash_on_failure in definition.features:
|
|
||||||
config.expect_crash_on_failure = (
|
|
||||||
task_config.task.expect_crash_on_failure
|
|
||||||
if task_config.task.expect_crash_on_failure is not None
|
|
||||||
else True
|
|
||||||
)
|
|
||||||
|
|
||||||
if TaskFeature.coverage_filter in definition.features:
|
|
||||||
coverage_filter = task_config.task.coverage_filter
|
|
||||||
|
|
||||||
if coverage_filter is not None:
|
|
||||||
config.coverage_filter = coverage_filter
|
|
||||||
|
|
||||||
if TaskFeature.target_assembly in definition.features:
|
|
||||||
config.target_assembly = task_config.task.target_assembly
|
|
||||||
|
|
||||||
if TaskFeature.target_class in definition.features:
|
|
||||||
config.target_class = task_config.task.target_class
|
|
||||||
|
|
||||||
if TaskFeature.target_method in definition.features:
|
|
||||||
config.target_method = task_config.task.target_method
|
|
||||||
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
def get_setup_container(config: TaskConfig) -> Container:
|
|
||||||
for container in config.containers:
|
|
||||||
if container.type == ContainerType.setup:
|
|
||||||
return container.name
|
|
||||||
|
|
||||||
raise TaskConfigError(
|
|
||||||
"task missing setup container: task_type = %s" % config.task.type
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TaskConfigError(Exception):
|
|
||||||
pass
|
|
@ -1,707 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
from onefuzztypes.enums import (
|
|
||||||
Compare,
|
|
||||||
ContainerPermission,
|
|
||||||
ContainerType,
|
|
||||||
TaskFeature,
|
|
||||||
TaskType,
|
|
||||||
)
|
|
||||||
from onefuzztypes.models import ContainerDefinition, TaskDefinition, VmDefinition
|
|
||||||
|
|
||||||
# all tasks are required to have a 'setup' container
|
|
||||||
TASK_DEFINITIONS = {
|
|
||||||
TaskType.coverage: TaskDefinition(
|
|
||||||
features=[
|
|
||||||
TaskFeature.target_exe,
|
|
||||||
TaskFeature.target_env,
|
|
||||||
TaskFeature.target_options,
|
|
||||||
TaskFeature.target_timeout,
|
|
||||||
TaskFeature.coverage_filter,
|
|
||||||
TaskFeature.target_must_use_input,
|
|
||||||
],
|
|
||||||
vm=VmDefinition(compare=Compare.Equal, value=1),
|
|
||||||
containers=[
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.setup,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.readonly_inputs,
|
|
||||||
compare=Compare.AtLeast,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.coverage,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[
|
|
||||||
ContainerPermission.List,
|
|
||||||
ContainerPermission.Read,
|
|
||||||
ContainerPermission.Write,
|
|
||||||
],
|
|
||||||
),
|
|
||||||
],
|
|
||||||
monitor_queue=ContainerType.readonly_inputs,
|
|
||||||
),
|
|
||||||
TaskType.dotnet_coverage: TaskDefinition(
|
|
||||||
features=[
|
|
||||||
TaskFeature.target_exe,
|
|
||||||
TaskFeature.target_env,
|
|
||||||
TaskFeature.target_options,
|
|
||||||
TaskFeature.target_timeout,
|
|
||||||
TaskFeature.coverage_filter,
|
|
||||||
TaskFeature.target_must_use_input,
|
|
||||||
],
|
|
||||||
vm=VmDefinition(compare=Compare.Equal, value=1),
|
|
||||||
containers=[
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.setup,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.readonly_inputs,
|
|
||||||
compare=Compare.AtLeast,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.coverage,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[
|
|
||||||
ContainerPermission.List,
|
|
||||||
ContainerPermission.Read,
|
|
||||||
ContainerPermission.Write,
|
|
||||||
],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.tools,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
],
|
|
||||||
monitor_queue=ContainerType.readonly_inputs,
|
|
||||||
),
|
|
||||||
TaskType.dotnet_crash_report: TaskDefinition(
|
|
||||||
features=[
|
|
||||||
TaskFeature.target_exe,
|
|
||||||
TaskFeature.target_env,
|
|
||||||
TaskFeature.target_options,
|
|
||||||
TaskFeature.target_timeout,
|
|
||||||
TaskFeature.check_asan_log,
|
|
||||||
TaskFeature.check_debugger,
|
|
||||||
TaskFeature.check_retry_count,
|
|
||||||
TaskFeature.minimized_stack_depth,
|
|
||||||
],
|
|
||||||
vm=VmDefinition(compare=Compare.AtLeast, value=1),
|
|
||||||
containers=[
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.setup,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.crashes,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.reports,
|
|
||||||
compare=Compare.AtMost,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Write],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.unique_reports,
|
|
||||||
compare=Compare.AtMost,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Write],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.no_repro,
|
|
||||||
compare=Compare.AtMost,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Write],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.tools,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
],
|
|
||||||
monitor_queue=ContainerType.crashes,
|
|
||||||
),
|
|
||||||
TaskType.libfuzzer_dotnet_fuzz: TaskDefinition(
|
|
||||||
features=[
|
|
||||||
TaskFeature.target_exe,
|
|
||||||
TaskFeature.target_env,
|
|
||||||
TaskFeature.target_options,
|
|
||||||
TaskFeature.target_workers,
|
|
||||||
TaskFeature.ensemble_sync_delay,
|
|
||||||
TaskFeature.check_fuzzer_help,
|
|
||||||
TaskFeature.expect_crash_on_failure,
|
|
||||||
TaskFeature.target_assembly,
|
|
||||||
TaskFeature.target_class,
|
|
||||||
TaskFeature.target_method,
|
|
||||||
],
|
|
||||||
vm=VmDefinition(compare=Compare.AtLeast, value=1),
|
|
||||||
containers=[
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.setup,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.crashes,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Write],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.inputs,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[
|
|
||||||
ContainerPermission.Write,
|
|
||||||
ContainerPermission.Read,
|
|
||||||
ContainerPermission.List,
|
|
||||||
],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.readonly_inputs,
|
|
||||||
compare=Compare.AtLeast,
|
|
||||||
value=0,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.tools,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
],
|
|
||||||
monitor_queue=None,
|
|
||||||
),
|
|
||||||
TaskType.generic_analysis: TaskDefinition(
|
|
||||||
features=[
|
|
||||||
TaskFeature.target_exe,
|
|
||||||
TaskFeature.target_options,
|
|
||||||
TaskFeature.analyzer_exe,
|
|
||||||
TaskFeature.analyzer_env,
|
|
||||||
TaskFeature.analyzer_options,
|
|
||||||
],
|
|
||||||
vm=VmDefinition(compare=Compare.AtLeast, value=1),
|
|
||||||
containers=[
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.setup,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.analysis,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[
|
|
||||||
ContainerPermission.Write,
|
|
||||||
ContainerPermission.Read,
|
|
||||||
ContainerPermission.List,
|
|
||||||
],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.crashes,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.tools,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
],
|
|
||||||
monitor_queue=ContainerType.crashes,
|
|
||||||
),
|
|
||||||
TaskType.libfuzzer_fuzz: TaskDefinition(
|
|
||||||
features=[
|
|
||||||
TaskFeature.target_exe,
|
|
||||||
TaskFeature.target_env,
|
|
||||||
TaskFeature.target_options,
|
|
||||||
TaskFeature.target_workers,
|
|
||||||
TaskFeature.ensemble_sync_delay,
|
|
||||||
TaskFeature.check_fuzzer_help,
|
|
||||||
TaskFeature.expect_crash_on_failure,
|
|
||||||
],
|
|
||||||
vm=VmDefinition(compare=Compare.AtLeast, value=1),
|
|
||||||
containers=[
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.setup,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.crashes,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Write],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.inputs,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[
|
|
||||||
ContainerPermission.Write,
|
|
||||||
ContainerPermission.Read,
|
|
||||||
ContainerPermission.List,
|
|
||||||
],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.readonly_inputs,
|
|
||||||
compare=Compare.AtLeast,
|
|
||||||
value=0,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
],
|
|
||||||
monitor_queue=None,
|
|
||||||
),
|
|
||||||
TaskType.libfuzzer_crash_report: TaskDefinition(
|
|
||||||
features=[
|
|
||||||
TaskFeature.target_exe,
|
|
||||||
TaskFeature.target_env,
|
|
||||||
TaskFeature.target_options,
|
|
||||||
TaskFeature.target_timeout,
|
|
||||||
TaskFeature.check_retry_count,
|
|
||||||
TaskFeature.check_fuzzer_help,
|
|
||||||
TaskFeature.minimized_stack_depth,
|
|
||||||
],
|
|
||||||
vm=VmDefinition(compare=Compare.AtLeast, value=1),
|
|
||||||
containers=[
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.setup,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.crashes,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.reports,
|
|
||||||
compare=Compare.AtMost,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Write],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.unique_reports,
|
|
||||||
compare=Compare.AtMost,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Write],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.no_repro,
|
|
||||||
compare=Compare.AtMost,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Write],
|
|
||||||
),
|
|
||||||
],
|
|
||||||
monitor_queue=ContainerType.crashes,
|
|
||||||
),
|
|
||||||
TaskType.libfuzzer_merge: TaskDefinition(
|
|
||||||
features=[
|
|
||||||
TaskFeature.target_exe,
|
|
||||||
TaskFeature.target_env,
|
|
||||||
TaskFeature.target_options,
|
|
||||||
TaskFeature.check_fuzzer_help,
|
|
||||||
],
|
|
||||||
vm=VmDefinition(compare=Compare.Equal, value=1),
|
|
||||||
containers=[
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.setup,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.unique_inputs,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[
|
|
||||||
ContainerPermission.List,
|
|
||||||
ContainerPermission.Read,
|
|
||||||
ContainerPermission.Write,
|
|
||||||
],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.inputs,
|
|
||||||
compare=Compare.AtLeast,
|
|
||||||
value=0,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
],
|
|
||||||
monitor_queue=None,
|
|
||||||
),
|
|
||||||
TaskType.generic_supervisor: TaskDefinition(
|
|
||||||
features=[
|
|
||||||
TaskFeature.target_exe,
|
|
||||||
TaskFeature.target_options,
|
|
||||||
TaskFeature.supervisor_exe,
|
|
||||||
TaskFeature.supervisor_env,
|
|
||||||
TaskFeature.supervisor_options,
|
|
||||||
TaskFeature.supervisor_input_marker,
|
|
||||||
TaskFeature.wait_for_files,
|
|
||||||
TaskFeature.stats_file,
|
|
||||||
TaskFeature.ensemble_sync_delay,
|
|
||||||
],
|
|
||||||
vm=VmDefinition(compare=Compare.AtLeast, value=1),
|
|
||||||
containers=[
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.setup,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.tools,
|
|
||||||
compare=Compare.AtMost,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.crashes,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Write],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.inputs,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[
|
|
||||||
ContainerPermission.Write,
|
|
||||||
ContainerPermission.Read,
|
|
||||||
ContainerPermission.List,
|
|
||||||
],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.unique_reports,
|
|
||||||
compare=Compare.AtMost,
|
|
||||||
value=1,
|
|
||||||
permissions=[
|
|
||||||
ContainerPermission.Write,
|
|
||||||
ContainerPermission.Read,
|
|
||||||
ContainerPermission.List,
|
|
||||||
],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.reports,
|
|
||||||
compare=Compare.AtMost,
|
|
||||||
value=1,
|
|
||||||
permissions=[
|
|
||||||
ContainerPermission.Write,
|
|
||||||
ContainerPermission.Read,
|
|
||||||
ContainerPermission.List,
|
|
||||||
],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.no_repro,
|
|
||||||
compare=Compare.AtMost,
|
|
||||||
value=1,
|
|
||||||
permissions=[
|
|
||||||
ContainerPermission.Write,
|
|
||||||
ContainerPermission.Read,
|
|
||||||
ContainerPermission.List,
|
|
||||||
],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.coverage,
|
|
||||||
compare=Compare.AtMost,
|
|
||||||
value=1,
|
|
||||||
permissions=[
|
|
||||||
ContainerPermission.Write,
|
|
||||||
ContainerPermission.Read,
|
|
||||||
ContainerPermission.List,
|
|
||||||
],
|
|
||||||
),
|
|
||||||
],
|
|
||||||
monitor_queue=None,
|
|
||||||
),
|
|
||||||
TaskType.generic_merge: TaskDefinition(
|
|
||||||
features=[
|
|
||||||
TaskFeature.target_exe,
|
|
||||||
TaskFeature.target_options,
|
|
||||||
TaskFeature.supervisor_exe,
|
|
||||||
TaskFeature.supervisor_env,
|
|
||||||
TaskFeature.supervisor_options,
|
|
||||||
TaskFeature.supervisor_input_marker,
|
|
||||||
TaskFeature.stats_file,
|
|
||||||
TaskFeature.preserve_existing_outputs,
|
|
||||||
],
|
|
||||||
vm=VmDefinition(compare=Compare.AtLeast, value=1),
|
|
||||||
containers=[
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.setup,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.tools,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.readonly_inputs,
|
|
||||||
compare=Compare.AtLeast,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.inputs,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Write, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
],
|
|
||||||
monitor_queue=None,
|
|
||||||
),
|
|
||||||
TaskType.generic_generator: TaskDefinition(
|
|
||||||
features=[
|
|
||||||
TaskFeature.generator_exe,
|
|
||||||
TaskFeature.generator_env,
|
|
||||||
TaskFeature.generator_options,
|
|
||||||
TaskFeature.target_exe,
|
|
||||||
TaskFeature.target_env,
|
|
||||||
TaskFeature.target_options,
|
|
||||||
TaskFeature.rename_output,
|
|
||||||
TaskFeature.target_timeout,
|
|
||||||
TaskFeature.check_asan_log,
|
|
||||||
TaskFeature.check_debugger,
|
|
||||||
TaskFeature.check_retry_count,
|
|
||||||
TaskFeature.ensemble_sync_delay,
|
|
||||||
TaskFeature.target_must_use_input,
|
|
||||||
],
|
|
||||||
vm=VmDefinition(compare=Compare.AtLeast, value=1),
|
|
||||||
containers=[
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.setup,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.tools,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.crashes,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Write],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.readonly_inputs,
|
|
||||||
compare=Compare.AtLeast,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
],
|
|
||||||
monitor_queue=None,
|
|
||||||
),
|
|
||||||
TaskType.generic_crash_report: TaskDefinition(
|
|
||||||
features=[
|
|
||||||
TaskFeature.target_exe,
|
|
||||||
TaskFeature.target_env,
|
|
||||||
TaskFeature.target_options,
|
|
||||||
TaskFeature.target_timeout,
|
|
||||||
TaskFeature.check_asan_log,
|
|
||||||
TaskFeature.check_debugger,
|
|
||||||
TaskFeature.check_retry_count,
|
|
||||||
TaskFeature.minimized_stack_depth,
|
|
||||||
],
|
|
||||||
vm=VmDefinition(compare=Compare.AtLeast, value=1),
|
|
||||||
containers=[
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.setup,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.crashes,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.reports,
|
|
||||||
compare=Compare.AtMost,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Write],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.unique_reports,
|
|
||||||
compare=Compare.AtMost,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Write],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.no_repro,
|
|
||||||
compare=Compare.AtMost,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Write],
|
|
||||||
),
|
|
||||||
],
|
|
||||||
monitor_queue=ContainerType.crashes,
|
|
||||||
),
|
|
||||||
TaskType.generic_regression: TaskDefinition(
|
|
||||||
features=[
|
|
||||||
TaskFeature.target_exe,
|
|
||||||
TaskFeature.target_env,
|
|
||||||
TaskFeature.target_options,
|
|
||||||
TaskFeature.target_timeout,
|
|
||||||
TaskFeature.check_asan_log,
|
|
||||||
TaskFeature.check_debugger,
|
|
||||||
TaskFeature.check_retry_count,
|
|
||||||
TaskFeature.report_list,
|
|
||||||
TaskFeature.minimized_stack_depth,
|
|
||||||
],
|
|
||||||
vm=VmDefinition(compare=Compare.AtLeast, value=1),
|
|
||||||
containers=[
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.setup,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.regression_reports,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[
|
|
||||||
ContainerPermission.Write,
|
|
||||||
ContainerPermission.Read,
|
|
||||||
ContainerPermission.List,
|
|
||||||
],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.crashes,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.reports,
|
|
||||||
compare=Compare.AtMost,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.unique_reports,
|
|
||||||
compare=Compare.AtMost,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.no_repro,
|
|
||||||
compare=Compare.AtMost,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.readonly_inputs,
|
|
||||||
compare=Compare.AtMost,
|
|
||||||
value=1,
|
|
||||||
permissions=[
|
|
||||||
ContainerPermission.Read,
|
|
||||||
ContainerPermission.List,
|
|
||||||
],
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
TaskType.libfuzzer_regression: TaskDefinition(
|
|
||||||
features=[
|
|
||||||
TaskFeature.target_exe,
|
|
||||||
TaskFeature.target_env,
|
|
||||||
TaskFeature.target_options,
|
|
||||||
TaskFeature.target_timeout,
|
|
||||||
TaskFeature.check_fuzzer_help,
|
|
||||||
TaskFeature.check_retry_count,
|
|
||||||
TaskFeature.report_list,
|
|
||||||
TaskFeature.minimized_stack_depth,
|
|
||||||
],
|
|
||||||
vm=VmDefinition(compare=Compare.AtLeast, value=1),
|
|
||||||
containers=[
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.setup,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.regression_reports,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[
|
|
||||||
ContainerPermission.Write,
|
|
||||||
ContainerPermission.Read,
|
|
||||||
ContainerPermission.List,
|
|
||||||
],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.crashes,
|
|
||||||
compare=Compare.Equal,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.unique_reports,
|
|
||||||
compare=Compare.AtMost,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.reports,
|
|
||||||
compare=Compare.AtMost,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.no_repro,
|
|
||||||
compare=Compare.AtMost,
|
|
||||||
value=1,
|
|
||||||
permissions=[ContainerPermission.Read, ContainerPermission.List],
|
|
||||||
),
|
|
||||||
ContainerDefinition(
|
|
||||||
type=ContainerType.readonly_inputs,
|
|
||||||
compare=Compare.AtMost,
|
|
||||||
value=1,
|
|
||||||
permissions=[
|
|
||||||
ContainerPermission.Read,
|
|
||||||
ContainerPermission.List,
|
|
||||||
],
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
}
|
|
@ -1,360 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import List, Optional, Tuple, Union
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from onefuzztypes.enums import ErrorCode, TaskState
|
|
||||||
from onefuzztypes.events import (
|
|
||||||
EventTaskCreated,
|
|
||||||
EventTaskFailed,
|
|
||||||
EventTaskStateUpdated,
|
|
||||||
EventTaskStopped,
|
|
||||||
)
|
|
||||||
from onefuzztypes.models import Error
|
|
||||||
from onefuzztypes.models import Task as BASE_TASK
|
|
||||||
from onefuzztypes.models import TaskConfig, TaskVm, UserInfo
|
|
||||||
from onefuzztypes.primitives import PoolName
|
|
||||||
|
|
||||||
from ..azure.image import get_os
|
|
||||||
from ..azure.queue import create_queue, delete_queue
|
|
||||||
from ..azure.storage import StorageType
|
|
||||||
from ..events import send_event
|
|
||||||
from ..orm import MappingIntStrAny, ORMMixin, QueryFilter
|
|
||||||
from ..workers.nodes import Node, NodeTasks
|
|
||||||
from ..workers.pools import Pool
|
|
||||||
from ..workers.scalesets import Scaleset
|
|
||||||
|
|
||||||
|
|
||||||
class Task(BASE_TASK, ORMMixin):
|
|
||||||
def check_prereq_tasks(self) -> bool:
|
|
||||||
if self.config.prereq_tasks:
|
|
||||||
for task_id in self.config.prereq_tasks:
|
|
||||||
task = Task.get_by_task_id(task_id)
|
|
||||||
# if a prereq task fails, then mark this task as failed
|
|
||||||
if isinstance(task, Error):
|
|
||||||
self.mark_failed(task)
|
|
||||||
return False
|
|
||||||
|
|
||||||
if task.state not in task.state.has_started():
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create(
|
|
||||||
cls, config: TaskConfig, job_id: UUID, user_info: UserInfo
|
|
||||||
) -> Union["Task", Error]:
|
|
||||||
if config.vm:
|
|
||||||
os = get_os(config.vm.region, config.vm.image)
|
|
||||||
if isinstance(os, Error):
|
|
||||||
return os
|
|
||||||
elif config.pool:
|
|
||||||
pool = Pool.get_by_name(config.pool.pool_name)
|
|
||||||
if isinstance(pool, Error):
|
|
||||||
return pool
|
|
||||||
os = pool.os
|
|
||||||
else:
|
|
||||||
raise Exception("task must have vm or pool")
|
|
||||||
task = cls(config=config, job_id=job_id, os=os, user_info=user_info)
|
|
||||||
task.save()
|
|
||||||
send_event(
|
|
||||||
EventTaskCreated(
|
|
||||||
job_id=task.job_id,
|
|
||||||
task_id=task.task_id,
|
|
||||||
config=config,
|
|
||||||
user_info=user_info,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.info(
|
|
||||||
"created task. job_id:%s task_id:%s type:%s",
|
|
||||||
task.job_id,
|
|
||||||
task.task_id,
|
|
||||||
task.config.task.type.name,
|
|
||||||
)
|
|
||||||
return task
|
|
||||||
|
|
||||||
def save_exclude(self) -> Optional[MappingIntStrAny]:
|
|
||||||
return {"heartbeats": ...}
|
|
||||||
|
|
||||||
def is_ready(self) -> bool:
|
|
||||||
if self.config.prereq_tasks:
|
|
||||||
for prereq_id in self.config.prereq_tasks:
|
|
||||||
prereq = Task.get_by_task_id(prereq_id)
|
|
||||||
if isinstance(prereq, Error):
|
|
||||||
logging.info("task prereq has error: %s - %s", self.task_id, prereq)
|
|
||||||
self.mark_failed(prereq)
|
|
||||||
return False
|
|
||||||
if prereq.state != TaskState.running:
|
|
||||||
logging.info(
|
|
||||||
"task is waiting on prereq: %s - %s:",
|
|
||||||
self.task_id,
|
|
||||||
prereq.task_id,
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
# At current, the telemetry filter will generate something like this:
|
|
||||||
#
|
|
||||||
# {
|
|
||||||
# 'job_id': 'f4a20fd8-0dcc-4a4e-8804-6ef7df50c978',
|
|
||||||
# 'task_id': '835f7b3f-43ad-4718-b7e4-d506d9667b09',
|
|
||||||
# 'state': 'stopped',
|
|
||||||
# 'config': {
|
|
||||||
# 'task': {'type': 'coverage'},
|
|
||||||
# 'vm': {'count': 1}
|
|
||||||
# }
|
|
||||||
# }
|
|
||||||
def telemetry_include(self) -> Optional[MappingIntStrAny]:
|
|
||||||
return {
|
|
||||||
"job_id": ...,
|
|
||||||
"task_id": ...,
|
|
||||||
"state": ...,
|
|
||||||
"config": {"vm": {"count": ...}, "task": {"type": ...}},
|
|
||||||
}
|
|
||||||
|
|
||||||
def init(self) -> None:
|
|
||||||
create_queue(self.task_id, StorageType.corpus)
|
|
||||||
self.set_state(TaskState.waiting)
|
|
||||||
|
|
||||||
def stopping(self) -> None:
|
|
||||||
logging.info("stopping task: %s:%s", self.job_id, self.task_id)
|
|
||||||
Node.stop_task(self.task_id)
|
|
||||||
if not NodeTasks.get_nodes_by_task_id(self.task_id):
|
|
||||||
self.stopped()
|
|
||||||
|
|
||||||
def stopped(self) -> None:
|
|
||||||
self.set_state(TaskState.stopped)
|
|
||||||
delete_queue(str(self.task_id), StorageType.corpus)
|
|
||||||
|
|
||||||
# TODO: we need to 'unschedule' this task from the existing pools
|
|
||||||
from ..jobs import Job
|
|
||||||
|
|
||||||
job = Job.get(self.job_id)
|
|
||||||
if job:
|
|
||||||
job.stop_if_all_done()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def search_states(
|
|
||||||
cls, *, job_id: Optional[UUID] = None, states: Optional[List[TaskState]] = None
|
|
||||||
) -> List["Task"]:
|
|
||||||
query: QueryFilter = {}
|
|
||||||
if job_id:
|
|
||||||
query["job_id"] = [job_id]
|
|
||||||
if states:
|
|
||||||
query["state"] = states
|
|
||||||
|
|
||||||
return cls.search(query=query)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def search_expired(cls) -> List["Task"]:
|
|
||||||
time_filter = "end_time lt datetime'%s'" % datetime.utcnow().isoformat()
|
|
||||||
return cls.search(
|
|
||||||
query={"state": TaskState.available()}, raw_unchecked_filter=time_filter
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_by_task_id(cls, task_id: UUID) -> Union[Error, "Task"]:
|
|
||||||
tasks = cls.search(query={"task_id": [task_id]})
|
|
||||||
if not tasks:
|
|
||||||
return Error(code=ErrorCode.INVALID_REQUEST, errors=["unable to find task"])
|
|
||||||
|
|
||||||
if len(tasks) != 1:
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.INVALID_REQUEST, errors=["error identifying task"]
|
|
||||||
)
|
|
||||||
task = tasks[0]
|
|
||||||
return task
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_tasks_by_pool_name(cls, pool_name: PoolName) -> List["Task"]:
|
|
||||||
tasks = cls.search_states(states=TaskState.available())
|
|
||||||
if not tasks:
|
|
||||||
return []
|
|
||||||
|
|
||||||
pool_tasks = []
|
|
||||||
|
|
||||||
for task in tasks:
|
|
||||||
task_pool = task.get_pool()
|
|
||||||
if not task_pool:
|
|
||||||
continue
|
|
||||||
if pool_name == task_pool.name:
|
|
||||||
pool_tasks.append(task)
|
|
||||||
|
|
||||||
return pool_tasks
|
|
||||||
|
|
||||||
def mark_stopping(self) -> None:
|
|
||||||
if self.state in TaskState.shutting_down():
|
|
||||||
logging.debug(
|
|
||||||
"ignoring post-task stop calls to stop %s:%s", self.job_id, self.task_id
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
if self.state not in TaskState.has_started():
|
|
||||||
self.mark_failed(
|
|
||||||
Error(code=ErrorCode.TASK_FAILED, errors=["task never started"])
|
|
||||||
)
|
|
||||||
|
|
||||||
self.set_state(TaskState.stopping)
|
|
||||||
|
|
||||||
def mark_failed(
|
|
||||||
self, error: Error, tasks_in_job: Optional[List["Task"]] = None
|
|
||||||
) -> None:
|
|
||||||
if self.state in TaskState.shutting_down():
|
|
||||||
logging.debug(
|
|
||||||
"ignoring post-task stop failures for %s:%s", self.job_id, self.task_id
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
if self.error is not None:
|
|
||||||
logging.debug(
|
|
||||||
"ignoring additional task error %s:%s", self.job_id, self.task_id
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
logging.error("task failed %s:%s - %s", self.job_id, self.task_id, error)
|
|
||||||
|
|
||||||
self.error = error
|
|
||||||
self.set_state(TaskState.stopping)
|
|
||||||
|
|
||||||
self.mark_dependants_failed(tasks_in_job=tasks_in_job)
|
|
||||||
|
|
||||||
def mark_dependants_failed(
|
|
||||||
self, tasks_in_job: Optional[List["Task"]] = None
|
|
||||||
) -> None:
|
|
||||||
if tasks_in_job is None:
|
|
||||||
tasks_in_job = Task.search(query={"job_id": [self.job_id]})
|
|
||||||
|
|
||||||
for task in tasks_in_job:
|
|
||||||
if task.config.prereq_tasks:
|
|
||||||
if self.task_id in task.config.prereq_tasks:
|
|
||||||
task.mark_failed(
|
|
||||||
Error(
|
|
||||||
code=ErrorCode.TASK_FAILED,
|
|
||||||
errors=[
|
|
||||||
"prerequisite task failed. task_id:%s" % self.task_id
|
|
||||||
],
|
|
||||||
),
|
|
||||||
tasks_in_job,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_pool(self) -> Optional[Pool]:
|
|
||||||
if self.config.pool:
|
|
||||||
pool = Pool.get_by_name(self.config.pool.pool_name)
|
|
||||||
if isinstance(pool, Error):
|
|
||||||
logging.info(
|
|
||||||
"unable to schedule task to pool: %s - %s", self.task_id, pool
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
return pool
|
|
||||||
elif self.config.vm:
|
|
||||||
scalesets = Scaleset.search()
|
|
||||||
scalesets = [
|
|
||||||
x
|
|
||||||
for x in scalesets
|
|
||||||
if x.vm_sku == self.config.vm.sku and x.image == self.config.vm.image
|
|
||||||
]
|
|
||||||
for scaleset in scalesets:
|
|
||||||
pool = Pool.get_by_name(scaleset.pool_name)
|
|
||||||
if isinstance(pool, Error):
|
|
||||||
logging.info(
|
|
||||||
"unable to schedule task to pool: %s - %s",
|
|
||||||
self.task_id,
|
|
||||||
pool,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return pool
|
|
||||||
|
|
||||||
logging.warning(
|
|
||||||
"unable to find a scaleset that matches the task prereqs: %s",
|
|
||||||
self.task_id,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_repro_vm_config(self) -> Union[TaskVm, None]:
|
|
||||||
if self.config.vm:
|
|
||||||
return self.config.vm
|
|
||||||
|
|
||||||
if self.config.pool is None:
|
|
||||||
raise Exception("either pool or vm must be specified: %s" % self.task_id)
|
|
||||||
|
|
||||||
pool = Pool.get_by_name(self.config.pool.pool_name)
|
|
||||||
if isinstance(pool, Error):
|
|
||||||
logging.info("unable to find pool from task: %s", self.task_id)
|
|
||||||
return None
|
|
||||||
|
|
||||||
for scaleset in Scaleset.search_by_pool(self.config.pool.pool_name):
|
|
||||||
return TaskVm(
|
|
||||||
region=scaleset.region,
|
|
||||||
sku=scaleset.vm_sku,
|
|
||||||
image=scaleset.image,
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.warning(
|
|
||||||
"no scalesets are defined for task: %s:%s", self.job_id, self.task_id
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def on_start(self) -> None:
|
|
||||||
# try to keep this effectively idempotent
|
|
||||||
|
|
||||||
if self.end_time is None:
|
|
||||||
self.end_time = datetime.utcnow() + timedelta(
|
|
||||||
hours=self.config.task.duration
|
|
||||||
)
|
|
||||||
|
|
||||||
from ..jobs import Job
|
|
||||||
|
|
||||||
job = Job.get(self.job_id)
|
|
||||||
if job:
|
|
||||||
job.on_start()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def key_fields(cls) -> Tuple[str, str]:
|
|
||||||
return ("job_id", "task_id")
|
|
||||||
|
|
||||||
def set_state(self, state: TaskState) -> None:
|
|
||||||
if self.state == state:
|
|
||||||
return
|
|
||||||
|
|
||||||
self.state = state
|
|
||||||
if self.state in [TaskState.running, TaskState.setting_up]:
|
|
||||||
self.on_start()
|
|
||||||
|
|
||||||
self.save()
|
|
||||||
|
|
||||||
if self.state == TaskState.stopped:
|
|
||||||
if self.error:
|
|
||||||
send_event(
|
|
||||||
EventTaskFailed(
|
|
||||||
job_id=self.job_id,
|
|
||||||
task_id=self.task_id,
|
|
||||||
error=self.error,
|
|
||||||
user_info=self.user_info,
|
|
||||||
config=self.config,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
send_event(
|
|
||||||
EventTaskStopped(
|
|
||||||
job_id=self.job_id,
|
|
||||||
task_id=self.task_id,
|
|
||||||
user_info=self.user_info,
|
|
||||||
config=self.config,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
send_event(
|
|
||||||
EventTaskStateUpdated(
|
|
||||||
job_id=self.job_id,
|
|
||||||
task_id=self.task_id,
|
|
||||||
state=self.state,
|
|
||||||
end_time=self.end_time,
|
|
||||||
config=self.config,
|
|
||||||
)
|
|
||||||
)
|
|
@ -1,248 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Dict, Generator, List, Optional, Tuple, TypeVar
|
|
||||||
from uuid import UUID, uuid4
|
|
||||||
|
|
||||||
from onefuzztypes.enums import OS, PoolState, TaskState
|
|
||||||
from onefuzztypes.models import WorkSet, WorkUnit
|
|
||||||
from onefuzztypes.primitives import Container
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from ..azure.containers import blob_exists, get_container_sas_url
|
|
||||||
from ..azure.storage import StorageType
|
|
||||||
from ..jobs import Job
|
|
||||||
from ..workers.pools import Pool
|
|
||||||
from .config import build_task_config, get_setup_container
|
|
||||||
from .main import Task
|
|
||||||
|
|
||||||
HOURS = 60 * 60
|
|
||||||
|
|
||||||
# TODO: eventually, this should be tied to the pool.
|
|
||||||
MAX_TASKS_PER_SET = 10
|
|
||||||
|
|
||||||
|
|
||||||
A = TypeVar("A")
|
|
||||||
|
|
||||||
|
|
||||||
def chunks(items: List[A], size: int) -> Generator[List[A], None, None]:
|
|
||||||
return (items[x : x + size] for x in range(0, len(items), size))
|
|
||||||
|
|
||||||
|
|
||||||
def schedule_workset(workset: WorkSet, pool: Pool, count: int) -> bool:
|
|
||||||
if pool.state not in PoolState.available():
|
|
||||||
logging.info(
|
|
||||||
"pool not available for work: %s state: %s", pool.name, pool.state.name
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
for _ in range(count):
|
|
||||||
if not pool.schedule_workset(workset):
|
|
||||||
logging.error(
|
|
||||||
"unable to schedule workset. pool:%s workset:%s", pool.name, workset
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
# TODO - Once Pydantic supports hashable models, the Tuple should be replaced
|
|
||||||
# with a model.
|
|
||||||
#
|
|
||||||
# For info: https://github.com/samuelcolvin/pydantic/pull/1881
|
|
||||||
|
|
||||||
|
|
||||||
def bucket_tasks(tasks: List[Task]) -> Dict[Tuple, List[Task]]:
|
|
||||||
# buckets are hashed by:
|
|
||||||
# OS, JOB ID, vm sku & image (if available), pool name (if available),
|
|
||||||
# if the setup script requires rebooting, and a 'unique' value
|
|
||||||
#
|
|
||||||
# The unique value is set based on the following conditions:
|
|
||||||
# * if the task is set to run on more than one VM, than we assume it can't be shared
|
|
||||||
# * if the task is missing the 'colocate' flag or it's set to False
|
|
||||||
|
|
||||||
buckets: Dict[Tuple, List[Task]] = {}
|
|
||||||
|
|
||||||
for task in tasks:
|
|
||||||
vm: Optional[Tuple[str, str]] = None
|
|
||||||
pool: Optional[str] = None
|
|
||||||
unique: Optional[UUID] = None
|
|
||||||
|
|
||||||
# check for multiple VMs for pre-1.0.0 tasks
|
|
||||||
if task.config.vm:
|
|
||||||
vm = (task.config.vm.sku, task.config.vm.image)
|
|
||||||
if task.config.vm.count > 1:
|
|
||||||
unique = uuid4()
|
|
||||||
|
|
||||||
# check for multiple VMs for 1.0.0 and later tasks
|
|
||||||
if task.config.pool:
|
|
||||||
pool = task.config.pool.pool_name
|
|
||||||
if task.config.pool.count > 1:
|
|
||||||
unique = uuid4()
|
|
||||||
|
|
||||||
if not task.config.colocate:
|
|
||||||
unique = uuid4()
|
|
||||||
|
|
||||||
key = (
|
|
||||||
task.os,
|
|
||||||
task.job_id,
|
|
||||||
vm,
|
|
||||||
pool,
|
|
||||||
get_setup_container(task.config),
|
|
||||||
task.config.task.reboot_after_setup,
|
|
||||||
unique,
|
|
||||||
)
|
|
||||||
if key not in buckets:
|
|
||||||
buckets[key] = []
|
|
||||||
buckets[key].append(task)
|
|
||||||
|
|
||||||
return buckets
|
|
||||||
|
|
||||||
|
|
||||||
class BucketConfig(BaseModel):
|
|
||||||
count: int
|
|
||||||
reboot: bool
|
|
||||||
setup_container: Container
|
|
||||||
setup_script: Optional[str]
|
|
||||||
pool: Pool
|
|
||||||
|
|
||||||
|
|
||||||
def build_work_unit(task: Task) -> Optional[Tuple[BucketConfig, WorkUnit]]:
|
|
||||||
pool = task.get_pool()
|
|
||||||
if not pool:
|
|
||||||
logging.info("unable to find pool for task: %s", task.task_id)
|
|
||||||
return None
|
|
||||||
|
|
||||||
logging.info("scheduling task: %s", task.task_id)
|
|
||||||
|
|
||||||
job = Job.get(task.job_id)
|
|
||||||
if not job:
|
|
||||||
raise Exception(f"invalid job_id {task.job_id} for task {task.task_id}")
|
|
||||||
|
|
||||||
task_config = build_task_config(job, task)
|
|
||||||
|
|
||||||
setup_container = get_setup_container(task.config)
|
|
||||||
setup_script = None
|
|
||||||
|
|
||||||
if task.os == OS.windows and blob_exists(
|
|
||||||
setup_container, "setup.ps1", StorageType.corpus
|
|
||||||
):
|
|
||||||
setup_script = "setup.ps1"
|
|
||||||
if task.os == OS.linux and blob_exists(
|
|
||||||
setup_container, "setup.sh", StorageType.corpus
|
|
||||||
):
|
|
||||||
setup_script = "setup.sh"
|
|
||||||
|
|
||||||
reboot = False
|
|
||||||
count = 1
|
|
||||||
if task.config.pool:
|
|
||||||
count = task.config.pool.count
|
|
||||||
|
|
||||||
# NOTE: "is True" is required to handle Optional[bool]
|
|
||||||
reboot = task.config.task.reboot_after_setup is True
|
|
||||||
elif task.config.vm:
|
|
||||||
# this branch should go away when we stop letting people specify
|
|
||||||
# VM configs directly.
|
|
||||||
count = task.config.vm.count
|
|
||||||
|
|
||||||
# NOTE: "is True" is required to handle Optional[bool]
|
|
||||||
reboot = (
|
|
||||||
task.config.vm.reboot_after_setup is True
|
|
||||||
or task.config.task.reboot_after_setup is True
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise TypeError
|
|
||||||
|
|
||||||
work_unit = WorkUnit(
|
|
||||||
job_id=task_config.job_id,
|
|
||||||
task_id=task_config.task_id,
|
|
||||||
task_type=task_config.task_type,
|
|
||||||
config=task_config.json(exclude_none=True, exclude_unset=True),
|
|
||||||
)
|
|
||||||
|
|
||||||
bucket_config = BucketConfig(
|
|
||||||
pool=pool,
|
|
||||||
count=count,
|
|
||||||
reboot=reboot,
|
|
||||||
setup_script=setup_script,
|
|
||||||
setup_container=setup_container,
|
|
||||||
)
|
|
||||||
|
|
||||||
return bucket_config, work_unit
|
|
||||||
|
|
||||||
|
|
||||||
def build_work_set(tasks: List[Task]) -> Optional[Tuple[BucketConfig, WorkSet]]:
|
|
||||||
task_ids = [x.task_id for x in tasks]
|
|
||||||
|
|
||||||
bucket_config: Optional[BucketConfig] = None
|
|
||||||
work_units = []
|
|
||||||
|
|
||||||
for task in tasks:
|
|
||||||
if task.config.prereq_tasks:
|
|
||||||
# if all of the prereqs are in this bucket, they will be
|
|
||||||
# scheduled together
|
|
||||||
if not all([task_id in task_ids for task_id in task.config.prereq_tasks]):
|
|
||||||
if not task.check_prereq_tasks():
|
|
||||||
continue
|
|
||||||
|
|
||||||
result = build_work_unit(task)
|
|
||||||
if not result:
|
|
||||||
continue
|
|
||||||
|
|
||||||
new_bucket_config, work_unit = result
|
|
||||||
if bucket_config is None:
|
|
||||||
bucket_config = new_bucket_config
|
|
||||||
else:
|
|
||||||
if bucket_config != new_bucket_config:
|
|
||||||
raise Exception(
|
|
||||||
f"bucket configs differ: {bucket_config} VS {new_bucket_config}"
|
|
||||||
)
|
|
||||||
|
|
||||||
work_units.append(work_unit)
|
|
||||||
|
|
||||||
if bucket_config:
|
|
||||||
setup_url = get_container_sas_url(
|
|
||||||
bucket_config.setup_container, StorageType.corpus, read=True, list_=True
|
|
||||||
)
|
|
||||||
|
|
||||||
work_set = WorkSet(
|
|
||||||
reboot=bucket_config.reboot,
|
|
||||||
script=(bucket_config.setup_script is not None),
|
|
||||||
setup_url=setup_url,
|
|
||||||
work_units=work_units,
|
|
||||||
)
|
|
||||||
return (bucket_config, work_set)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def schedule_tasks() -> None:
|
|
||||||
tasks: List[Task] = []
|
|
||||||
|
|
||||||
tasks = Task.search_states(states=[TaskState.waiting])
|
|
||||||
|
|
||||||
tasks_by_id = {x.task_id: x for x in tasks}
|
|
||||||
seen = set()
|
|
||||||
|
|
||||||
not_ready_count = 0
|
|
||||||
|
|
||||||
buckets = bucket_tasks(tasks)
|
|
||||||
|
|
||||||
for bucketed_tasks in buckets.values():
|
|
||||||
for chunk in chunks(bucketed_tasks, MAX_TASKS_PER_SET):
|
|
||||||
result = build_work_set(chunk)
|
|
||||||
if result is None:
|
|
||||||
continue
|
|
||||||
bucket_config, work_set = result
|
|
||||||
|
|
||||||
if schedule_workset(work_set, bucket_config.pool, bucket_config.count):
|
|
||||||
for work_unit in work_set.work_units:
|
|
||||||
task = tasks_by_id[work_unit.task_id]
|
|
||||||
task.set_state(TaskState.scheduled)
|
|
||||||
seen.add(task.task_id)
|
|
||||||
|
|
||||||
not_ready_count = len(tasks) - len(seen)
|
|
||||||
if not_ready_count > 0:
|
|
||||||
logging.info("tasks not ready: %d", not_ready_count)
|
|
@ -1,72 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from typing import Any, Dict, Optional, Union
|
|
||||||
|
|
||||||
from onefuzztypes.enums import TelemetryData, TelemetryEvent
|
|
||||||
from opencensus.ext.azure.log_exporter import AzureLogHandler
|
|
||||||
|
|
||||||
LOCAL_CLIENT: Optional[logging.Logger] = None
|
|
||||||
CENTRAL_CLIENT: Optional[logging.Logger] = None
|
|
||||||
|
|
||||||
|
|
||||||
def _get_client(environ_key: str) -> Optional[logging.Logger]:
|
|
||||||
key = os.environ.get(environ_key)
|
|
||||||
if key is None:
|
|
||||||
return None
|
|
||||||
client = logging.getLogger("onefuzz")
|
|
||||||
client.addHandler(AzureLogHandler(connection_string="InstrumentationKey=%s" % key))
|
|
||||||
return client
|
|
||||||
|
|
||||||
|
|
||||||
def _central_client() -> Optional[logging.Logger]:
|
|
||||||
global CENTRAL_CLIENT
|
|
||||||
if not CENTRAL_CLIENT:
|
|
||||||
CENTRAL_CLIENT = _get_client("ONEFUZZ_TELEMETRY")
|
|
||||||
return CENTRAL_CLIENT
|
|
||||||
|
|
||||||
|
|
||||||
def _local_client() -> Union[None, Any, logging.Logger]:
|
|
||||||
global LOCAL_CLIENT
|
|
||||||
if not LOCAL_CLIENT:
|
|
||||||
LOCAL_CLIENT = _get_client("APPINSIGHTS_INSTRUMENTATIONKEY")
|
|
||||||
return LOCAL_CLIENT
|
|
||||||
|
|
||||||
|
|
||||||
# NOTE: All telemetry that is *NOT* using the ORM telemetry_include should
|
|
||||||
# go through this method
|
|
||||||
#
|
|
||||||
# This provides a point of inspection to know if it's data that is safe to
|
|
||||||
# log to the central OneFuzz telemetry point
|
|
||||||
def track_event(
|
|
||||||
event: TelemetryEvent, data: Dict[TelemetryData, Union[str, int]]
|
|
||||||
) -> None:
|
|
||||||
central = _central_client()
|
|
||||||
local = _local_client()
|
|
||||||
|
|
||||||
if local:
|
|
||||||
serialized = {k.name: v for (k, v) in data.items()}
|
|
||||||
local.info(event.name, extra={"custom_dimensions": serialized})
|
|
||||||
|
|
||||||
if event in TelemetryEvent.can_share() and central:
|
|
||||||
serialized = {
|
|
||||||
k.name: v for (k, v) in data.items() if k in TelemetryData.can_share()
|
|
||||||
}
|
|
||||||
central.info(event.name, extra={"custom_dimensions": serialized})
|
|
||||||
|
|
||||||
|
|
||||||
# NOTE: This should *only* be used for logging Telemetry data that uses
|
|
||||||
# the ORM telemetry_include method to limit data for telemetry.
|
|
||||||
def track_event_filtered(event: TelemetryEvent, data: Any) -> None:
|
|
||||||
central = _central_client()
|
|
||||||
local = _local_client()
|
|
||||||
|
|
||||||
if local:
|
|
||||||
local.info(event.name, extra={"custom_dimensions": data})
|
|
||||||
|
|
||||||
if central and event in TelemetryEvent.can_share():
|
|
||||||
central.info(event.name, extra={"custom_dimensions": data})
|
|
@ -1,114 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Dict, Optional, Type
|
|
||||||
|
|
||||||
from azure.core.exceptions import ResourceNotFoundError
|
|
||||||
from msrestazure.azure_exceptions import CloudError
|
|
||||||
from onefuzztypes.enums import UpdateType
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from .azure.queue import queue_object
|
|
||||||
from .azure.storage import StorageType
|
|
||||||
|
|
||||||
|
|
||||||
# This class isn't intended to be shared outside of the service
|
|
||||||
class Update(BaseModel):
|
|
||||||
update_type: UpdateType
|
|
||||||
PartitionKey: Optional[str]
|
|
||||||
RowKey: Optional[str]
|
|
||||||
method: Optional[str]
|
|
||||||
|
|
||||||
|
|
||||||
def queue_update(
|
|
||||||
update_type: UpdateType,
|
|
||||||
PartitionKey: Optional[str] = None,
|
|
||||||
RowKey: Optional[str] = None,
|
|
||||||
method: Optional[str] = None,
|
|
||||||
visibility_timeout: int = None,
|
|
||||||
) -> None:
|
|
||||||
logging.info(
|
|
||||||
"queuing type:%s id:[%s,%s] method:%s timeout: %s",
|
|
||||||
update_type.name,
|
|
||||||
PartitionKey,
|
|
||||||
RowKey,
|
|
||||||
method,
|
|
||||||
visibility_timeout,
|
|
||||||
)
|
|
||||||
|
|
||||||
update = Update(
|
|
||||||
update_type=update_type, PartitionKey=PartitionKey, RowKey=RowKey, method=method
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
if not queue_object(
|
|
||||||
"update-queue",
|
|
||||||
update,
|
|
||||||
StorageType.config,
|
|
||||||
visibility_timeout=visibility_timeout,
|
|
||||||
):
|
|
||||||
logging.error("unable to queue update")
|
|
||||||
except (CloudError, ResourceNotFoundError) as err:
|
|
||||||
logging.error("GOT ERROR %s", repr(err))
|
|
||||||
logging.error("GOT ERROR %s", dir(err))
|
|
||||||
raise err
|
|
||||||
|
|
||||||
|
|
||||||
def execute_update(update: Update) -> None:
|
|
||||||
from .jobs import Job
|
|
||||||
from .orm import ORMMixin
|
|
||||||
from .proxy import Proxy
|
|
||||||
from .repro import Repro
|
|
||||||
from .tasks.main import Task
|
|
||||||
from .workers.nodes import Node
|
|
||||||
from .workers.pools import Pool
|
|
||||||
from .workers.scalesets import Scaleset
|
|
||||||
|
|
||||||
update_objects: Dict[UpdateType, Type[ORMMixin]] = {
|
|
||||||
UpdateType.Task: Task,
|
|
||||||
UpdateType.Job: Job,
|
|
||||||
UpdateType.Repro: Repro,
|
|
||||||
UpdateType.Proxy: Proxy,
|
|
||||||
UpdateType.Pool: Pool,
|
|
||||||
UpdateType.Node: Node,
|
|
||||||
UpdateType.Scaleset: Scaleset,
|
|
||||||
}
|
|
||||||
|
|
||||||
# TODO: remove these from being queued, these updates are handled elsewhere
|
|
||||||
if update.update_type == UpdateType.Scaleset:
|
|
||||||
return
|
|
||||||
|
|
||||||
if update.update_type in update_objects:
|
|
||||||
if update.PartitionKey is None or update.RowKey is None:
|
|
||||||
raise Exception("unsupported update: %s" % update)
|
|
||||||
|
|
||||||
obj = update_objects[update.update_type].get(update.PartitionKey, update.RowKey)
|
|
||||||
if not obj:
|
|
||||||
logging.error("unable find to obj to update %s", update)
|
|
||||||
return
|
|
||||||
|
|
||||||
if update.method and hasattr(obj, update.method):
|
|
||||||
logging.info("performing queued update: %s", update)
|
|
||||||
getattr(obj, update.method)()
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
state = getattr(obj, "state", None)
|
|
||||||
if state is None:
|
|
||||||
logging.error("queued update for object without state: %s", update)
|
|
||||||
return
|
|
||||||
func = getattr(obj, state.name, None)
|
|
||||||
if func is None:
|
|
||||||
logging.debug(
|
|
||||||
"no function to implement state: %s - %s", update, state.name
|
|
||||||
)
|
|
||||||
return
|
|
||||||
logging.info(
|
|
||||||
"performing queued update for state: %s - %s", update, state.name
|
|
||||||
)
|
|
||||||
func()
|
|
||||||
return
|
|
||||||
|
|
||||||
raise NotImplementedError("unimplemented update type: %s" % update.update_type.name)
|
|
@ -1,79 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
#
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT License.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import List, Optional
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
import azure.functions as func
|
|
||||||
import jwt
|
|
||||||
from memoization import cached
|
|
||||||
from onefuzztypes.enums import ErrorCode
|
|
||||||
from onefuzztypes.models import Error, Result, UserInfo
|
|
||||||
|
|
||||||
from .config import InstanceConfig
|
|
||||||
|
|
||||||
|
|
||||||
def get_bearer_token(request: func.HttpRequest) -> Optional[str]:
|
|
||||||
auth: str = request.headers.get("Authorization", None)
|
|
||||||
if not auth:
|
|
||||||
return None
|
|
||||||
|
|
||||||
parts = auth.split()
|
|
||||||
|
|
||||||
if len(parts) != 2:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if parts[0].lower() != "bearer":
|
|
||||||
return None
|
|
||||||
|
|
||||||
return parts[1]
|
|
||||||
|
|
||||||
|
|
||||||
def get_auth_token(request: func.HttpRequest) -> Optional[str]:
|
|
||||||
token = get_bearer_token(request)
|
|
||||||
if token is not None:
|
|
||||||
return token
|
|
||||||
|
|
||||||
token_header = request.headers.get("x-ms-token-aad-id-token", None)
|
|
||||||
if token_header is None:
|
|
||||||
return None
|
|
||||||
return str(token_header)
|
|
||||||
|
|
||||||
|
|
||||||
@cached(ttl=60)
|
|
||||||
def get_allowed_tenants() -> List[str]:
|
|
||||||
config = InstanceConfig.fetch()
|
|
||||||
entries = [f"https://sts.windows.net/{x}/" for x in config.allowed_aad_tenants]
|
|
||||||
return entries
|
|
||||||
|
|
||||||
|
|
||||||
def parse_jwt_token(request: func.HttpRequest) -> Result[UserInfo]:
|
|
||||||
"""Obtains the Access Token from the Authorization Header"""
|
|
||||||
token_str = get_auth_token(request)
|
|
||||||
if token_str is None:
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.INVALID_REQUEST,
|
|
||||||
errors=["unable to find authorization token"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# The JWT token has already been verified by the azure authentication layer,
|
|
||||||
# but we need to verify the tenant is as we expect.
|
|
||||||
token = jwt.decode(token_str, options={"verify_signature": False})
|
|
||||||
|
|
||||||
if "iss" not in token:
|
|
||||||
return Error(
|
|
||||||
code=ErrorCode.INVALID_REQUEST, errors=["missing issuer from token"]
|
|
||||||
)
|
|
||||||
|
|
||||||
tenants = get_allowed_tenants()
|
|
||||||
if token["iss"] not in tenants:
|
|
||||||
logging.error("issuer not from allowed tenant: %s - %s", token["iss"], tenants)
|
|
||||||
return Error(code=ErrorCode.INVALID_REQUEST, errors=["unauthorized AAD issuer"])
|
|
||||||
|
|
||||||
application_id = UUID(token["appid"]) if "appid" in token else None
|
|
||||||
object_id = UUID(token["oid"]) if "oid" in token else None
|
|
||||||
upn = token.get("upn")
|
|
||||||
return UserInfo(application_id=application_id, object_id=object_id, upn=upn)
|
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user