delete python function (#2763)

* delete python function

* fix build
Cheick Keita
2023-01-25 22:05:55 +00:00
committed by GitHub
parent d79e6d9864
commit 7de23af60c
186 changed files with 4 additions and 18165 deletions

View File

@@ -29,13 +29,6 @@ echo "layout python3" >> .envrc
direnv allow
pip install -e .
echo "Install api-service"
cd /workspaces/onefuzz/src/api-service
echo "layout python3" >> .envrc
direnv allow
pip install -r requirements-dev.txt
cd __app__
pip install -r requirements.txt
cd /workspaces/onefuzz/src/utils
chmod u+x lint.sh

View File

@@ -3,4 +3,3 @@ paths:
- src/agent
- src/pytypes
- src/deployment
- src/api-service/__app__

View File

@@ -1,7 +0,0 @@
[flake8]
# Recommend matching the black line length (default 88),
# rather than using the flake8 default of 79:
max-line-length = 88
extend-ignore =
# See https://github.com/PyCQA/pycodestyle/issues/373
E203,

View File

@@ -1,49 +0,0 @@
bin
obj
csx
.vs
edge
Publish
*.user
*.suo
*.cscfg
*.Cache
project.lock.json
/packages
/TestResults
/tools/NuGet.exe
/App_Data
/secrets
/data
.secrets
appsettings.json
local.settings.json
node_modules
dist
# Local python packages
.python_packages/
# Python Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
.dockerenv
host_secrets.json
local.settings.json
.mypy_cache/
__app__/onefuzzlib/build.id

View File

@@ -1,9 +0,0 @@
If you are doing development on the API service, you can build and deploy directly to your own instance using the azure-functions-core-tools.
From the api-service directory, do the following:
func azure functionapp publish <instance>
While Azure Functions will restart your instance with the new code, it may take a while. It may be helpful to restart your instance after pushing by doing the following:
az functionapp restart -g <group> -n <instance>
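For example, a minimal publish-and-restart sequence might look like this (my-instance and my-group are placeholder names for illustration, not values from this repository):
cd src/api-service
func azure functionapp publish my-instance
az functionapp restart -g my-group -n my-instance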

View File

@@ -1 +0,0 @@
local.settings.json

View File

@@ -1,4 +0,0 @@
.direnv
.python_packages
__pycache__
.venv

View File

@@ -1,17 +0,0 @@
import os
import certifi.core
def override_where() -> str:
"""overrides certifi.core.where to return actual location of cacert.pem"""
# see:
# https://github.com/Azure/azure-functions-durable-python/issues/194#issuecomment-710670377
# change this to match the location of cacert.pem
return os.path.abspath(
"cacert.pem"
) # or to whatever location you know contains the copy of cacert.pem
os.environ["REQUESTS_CA_BUNDLE"] = override_where()
certifi.core.where = override_where

View File

@@ -1,53 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import azure.functions as func
from onefuzztypes.enums import ErrorCode, TaskState
from onefuzztypes.models import Error
from onefuzztypes.requests import CanScheduleRequest
from onefuzztypes.responses import CanSchedule
from ..onefuzzlib.endpoint_authorization import call_if_agent
from ..onefuzzlib.request import not_ok, ok, parse_request
from ..onefuzzlib.tasks.main import Task
from ..onefuzzlib.workers.nodes import Node
def post(req: func.HttpRequest) -> func.HttpResponse:
request = parse_request(CanScheduleRequest, req)
if isinstance(request, Error):
return not_ok(request, context="CanScheduleRequest")
node = Node.get_by_machine_id(request.machine_id)
if not node:
return not_ok(
Error(code=ErrorCode.UNABLE_TO_FIND, errors=["unable to find node"]),
context=request.machine_id,
)
allowed = True
work_stopped = False
if not node.can_process_new_work():
allowed = False
task = Task.get_by_task_id(request.task_id)
work_stopped = isinstance(task, Error) or task.state in TaskState.shutting_down()
if work_stopped:
allowed = False
if allowed:
allowed = not isinstance(node.acquire_scale_in_protection(), Error)
return ok(CanSchedule(allowed=allowed, work_stopped=work_stopped))
def main(req: func.HttpRequest) -> func.HttpResponse:
methods = {"POST": post}
method = methods[req.method]
result = call_if_agent(req, method)
return result

View File

@@ -1,20 +0,0 @@
{
"scriptFile": "__init__.py",
"bindings": [
{
"authLevel": "anonymous",
"type": "httpTrigger",
"direction": "in",
"name": "req",
"methods": [
"post"
],
"route": "agents/can_schedule"
},
{
"type": "http",
"direction": "out",
"name": "$return"
}
]
}

View File

@@ -1,50 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import azure.functions as func
from onefuzztypes.models import Error, NodeCommandEnvelope
from onefuzztypes.requests import NodeCommandDelete, NodeCommandGet
from onefuzztypes.responses import BoolResult, PendingNodeCommand
from ..onefuzzlib.endpoint_authorization import call_if_agent
from ..onefuzzlib.request import not_ok, ok, parse_request
from ..onefuzzlib.workers.nodes import NodeMessage
def get(req: func.HttpRequest) -> func.HttpResponse:
request = parse_request(NodeCommandGet, req)
if isinstance(request, Error):
return not_ok(request, context="NodeCommandGet")
messages = NodeMessage.get_messages(request.machine_id, num_results=1)
if messages:
command = messages[0].message
message_id = messages[0].message_id
envelope = NodeCommandEnvelope(command=command, message_id=message_id)
return ok(PendingNodeCommand(envelope=envelope))
else:
return ok(PendingNodeCommand(envelope=None))
def delete(req: func.HttpRequest) -> func.HttpResponse:
request = parse_request(NodeCommandDelete, req)
if isinstance(request, Error):
return not_ok(request, context="NodeCommandDelete")
message = NodeMessage.get(request.machine_id, request.message_id)
if message:
message.delete()
return ok(BoolResult(result=True))
def main(req: func.HttpRequest) -> func.HttpResponse:
methods = {"DELETE": delete, "GET": get}
method = methods[req.method]
result = call_if_agent(req, method)
return result

View File

@@ -1,22 +0,0 @@
{
"scriptFile": "__init__.py",
"bindings": [
{
"authLevel": "anonymous",
"type": "httpTrigger",
"direction": "in",
"name": "req",
"methods": [
"get",
"post",
"delete"
],
"route": "agents/commands"
},
{
"type": "http",
"direction": "out",
"name": "$return"
}
]
}

View File

@@ -1,79 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import azure.functions as func
from onefuzztypes.models import (
Error,
NodeEvent,
NodeEventEnvelope,
NodeStateUpdate,
Result,
WorkerEvent,
)
from onefuzztypes.responses import BoolResult
from ..onefuzzlib.agent_events import on_state_update, on_worker_event
from ..onefuzzlib.endpoint_authorization import call_if_agent
from ..onefuzzlib.request import not_ok, ok, parse_request
def process(envelope: NodeEventEnvelope) -> Result[None]:
logging.info(
"node event: machine_id: %s event: %s",
envelope.machine_id,
envelope.event.json(exclude_none=True),
)
if isinstance(envelope.event, NodeStateUpdate):
return on_state_update(envelope.machine_id, envelope.event)
if isinstance(envelope.event, WorkerEvent):
return on_worker_event(envelope.machine_id, envelope.event)
if isinstance(envelope.event, NodeEvent):
if envelope.event.state_update:
result = on_state_update(envelope.machine_id, envelope.event.state_update)
if result is not None:
return result
if envelope.event.worker_event:
result = on_worker_event(envelope.machine_id, envelope.event.worker_event)
if result is not None:
return result
return None
raise NotImplementedError("invalid node event: %s" % envelope)
def post(req: func.HttpRequest) -> func.HttpResponse:
envelope = parse_request(NodeEventEnvelope, req)
if isinstance(envelope, Error):
return not_ok(envelope, context="node event")
logging.info(
"node event: machine_id: %s event: %s",
envelope.machine_id,
envelope.event.json(exclude_none=True),
)
result = process(envelope)
if isinstance(result, Error):
logging.error(
"unable to process agent event. envelope:%s error:%s", envelope, result
)
return not_ok(result, context="node event")
return ok(BoolResult(result=True))
def main(req: func.HttpRequest) -> func.HttpResponse:
methods = {"POST": post}
method = methods[req.method]
result = call_if_agent(req, method)
return result

View File

@@ -1,20 +0,0 @@
{
"scriptFile": "__init__.py",
"bindings": [
{
"authLevel": "anonymous",
"type": "httpTrigger",
"direction": "in",
"name": "req",
"methods": [
"post"
],
"route": "agents/events"
},
{
"type": "http",
"direction": "out",
"name": "$return"
}
]
}

View File

@@ -1,119 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import datetime
import logging
from uuid import UUID
import azure.functions as func
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error
from onefuzztypes.requests import AgentRegistrationGet, AgentRegistrationPost
from onefuzztypes.responses import AgentRegistration
from ..onefuzzlib.azure.creds import get_agent_instance_url
from ..onefuzzlib.azure.queue import get_queue_sas
from ..onefuzzlib.azure.storage import StorageType
from ..onefuzzlib.endpoint_authorization import call_if_agent
from ..onefuzzlib.request import not_ok, ok, parse_uri
from ..onefuzzlib.workers.nodes import Node
from ..onefuzzlib.workers.pools import Pool
def create_registration_response(machine_id: UUID, pool: Pool) -> func.HttpResponse:
base_address = get_agent_instance_url()
events_url = "%s/api/agents/events" % base_address
commands_url = "%s/api/agents/commands" % base_address
work_queue = get_queue_sas(
pool.get_pool_queue(),
StorageType.corpus,
read=True,
update=True,
process=True,
duration=datetime.timedelta(hours=24),
)
return ok(
AgentRegistration(
events_url=events_url,
commands_url=commands_url,
work_queue=work_queue,
)
)
def get(req: func.HttpRequest) -> func.HttpResponse:
get_registration = parse_uri(AgentRegistrationGet, req)
if isinstance(get_registration, Error):
return not_ok(get_registration, context="agent registration")
agent_node = Node.get_by_machine_id(get_registration.machine_id)
if agent_node is None:
return not_ok(
Error(
code=ErrorCode.INVALID_REQUEST,
errors=[
"unable to find a registration associated with machine_id '%s'"
% get_registration.machine_id
],
),
context="agent registration",
)
else:
pool = Pool.get_by_name(agent_node.pool_name)
if isinstance(pool, Error):
return not_ok(
Error(
code=ErrorCode.INVALID_REQUEST,
errors=[
"unable to find a pool associated with the provided machine_id"
],
),
context="agent registration",
)
return create_registration_response(agent_node.machine_id, pool)
def post(req: func.HttpRequest) -> func.HttpResponse:
registration_request = parse_uri(AgentRegistrationPost, req)
if isinstance(registration_request, Error):
return not_ok(registration_request, context="agent registration")
logging.info(
"registration request: %s", registration_request.json(exclude_none=True)
)
pool = Pool.get_by_name(registration_request.pool_name)
if isinstance(pool, Error):
return not_ok(
Error(
code=ErrorCode.INVALID_REQUEST,
errors=["unable to find pool '%s'" % registration_request.pool_name],
),
context="agent registration",
)
node = Node.get_by_machine_id(registration_request.machine_id)
if node:
node.delete()
node = Node.create(
pool_id=pool.pool_id,
pool_name=pool.name,
machine_id=registration_request.machine_id,
scaleset_id=registration_request.scaleset_id,
version=registration_request.version,
)
return create_registration_response(node.machine_id, pool)
def main(req: func.HttpRequest) -> func.HttpResponse:
methods = {"POST": post, "GET": get}
method = methods[req.method]
result = call_if_agent(req, method)
return result

View File

@@ -1,22 +0,0 @@
{
"scriptFile": "__init__.py",
"bindings": [
{
"authLevel": "anonymous",
"type": "httpTrigger",
"direction": "in",
"name": "req",
"methods": [
"get",
"post",
"delete"
],
"route": "agents/registration"
},
{
"type": "http",
"direction": "out",
"name": "$return"
}
]
}

View File

@@ -1,96 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
from typing import Optional, Union
import azure.functions as func
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error
from onefuzztypes.requests import ContainerCreate, ContainerDelete, ContainerGet
from onefuzztypes.responses import BoolResult, ContainerInfo, ContainerInfoBase
from ..onefuzzlib.azure.containers import (
create_container,
delete_container,
get_container_metadata,
get_container_sas_url,
get_containers,
)
from ..onefuzzlib.azure.storage import StorageType
from ..onefuzzlib.endpoint_authorization import call_if_user
from ..onefuzzlib.request import not_ok, ok, parse_request
def get(req: func.HttpRequest) -> func.HttpResponse:
request: Optional[Union[ContainerGet, Error]] = None
if req.get_body():
request = parse_request(ContainerGet, req)
if isinstance(request, Error):
return not_ok(request, context="container get")
if request is not None:
metadata = get_container_metadata(request.name, StorageType.corpus)
if metadata is None:
return not_ok(
Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid container"]),
context=request.name,
)
info = ContainerInfo(
name=request.name,
sas_url=get_container_sas_url(
request.name,
StorageType.corpus,
read=True,
write=True,
delete=True,
list_=True,
),
metadata=metadata,
)
return ok(info)
containers = get_containers(StorageType.corpus)
container_info = []
for name in containers:
container_info.append(ContainerInfoBase(name=name, metadata=containers[name]))
return ok(container_info)
def post(req: func.HttpRequest) -> func.HttpResponse:
request = parse_request(ContainerCreate, req)
if isinstance(request, Error):
return not_ok(request, context="container create")
logging.info("container - creating %s", request.name)
sas = create_container(request.name, StorageType.corpus, metadata=request.metadata)
if sas:
return ok(
ContainerInfo(name=request.name, sas_url=sas, metadata=request.metadata)
)
return not_ok(
Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid container"]),
context=request.name,
)
def delete(req: func.HttpRequest) -> func.HttpResponse:
request = parse_request(ContainerDelete, req)
if isinstance(request, Error):
return not_ok(request, context="container delete")
logging.info("container - deleting %s", request.name)
return ok(BoolResult(result=delete_container(request.name, StorageType.corpus)))
def main(req: func.HttpRequest) -> func.HttpResponse:
methods = {"GET": get, "POST": post, "DELETE": delete}
method = methods[req.method]
result = call_if_user(req, method)
return result

View File

@@ -1,21 +0,0 @@
{
"scriptFile": "__init__.py",
"bindings": [
{
"authLevel": "anonymous",
"type": "httpTrigger",
"direction": "in",
"name": "req",
"methods": [
"get",
"post",
"delete"
]
},
{
"type": "http",
"direction": "out",
"name": "$return"
}
]
}

View File

@@ -1,55 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from datetime import timedelta
import azure.functions as func
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error, FileEntry
from ..onefuzzlib.azure.containers import (
blob_exists,
container_exists,
get_file_sas_url,
)
from ..onefuzzlib.azure.storage import StorageType
from ..onefuzzlib.endpoint_authorization import call_if_user
from ..onefuzzlib.request import not_ok, parse_uri, redirect
def get(req: func.HttpRequest) -> func.HttpResponse:
request = parse_uri(FileEntry, req)
if isinstance(request, Error):
return not_ok(request, context="download")
if not container_exists(request.container, StorageType.corpus):
return not_ok(
Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid container"]),
context=request.container,
)
if not blob_exists(request.container, request.filename, StorageType.corpus):
return not_ok(
Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid filename"]),
context=request.filename,
)
return redirect(
get_file_sas_url(
request.container,
request.filename,
StorageType.corpus,
read=True,
duration=timedelta(minutes=5),
)
)
def main(req: func.HttpRequest) -> func.HttpResponse:
methods = {"GET": get}
method = methods[req.method]
result = call_if_user(req, method)
return result

View File

@@ -1,19 +0,0 @@
{
"scriptFile": "__init__.py",
"bindings": [
{
"authLevel": "anonymous",
"type": "httpTrigger",
"direction": "in",
"name": "req",
"methods": [
"get"
]
},
{
"type": "http",
"direction": "out",
"name": "$return"
}
]
}

View File

@@ -1,34 +0,0 @@
{
"version": "2.0",
"extensionBundle": {
"id": "Microsoft.Azure.Functions.ExtensionBundle",
"version": "[1.*, 2.0.0)"
},
"logging": {
"logLevel": {
"default": "Warning",
"Host": "Warning",
"Function": "Information",
"Host.Aggregator": "Warning",
"logging:logLevel:Worker": "Warning",
"logging:logLevel:Microsoft": "Warning",
"AzureFunctionsJobHost:logging:logLevel:Host.Function.Console": "Warning"
},
"applicationInsights": {
"samplingSettings": {
"isEnabled": true,
"InitialSamplingPercentage": 100,
"maxTelemetryItemsPerSecond": 20,
"excludedTypes": "Exception"
}
}
},
"extensions": {
"queues": {
"maxPollingInterval": "00:00:02",
"batchSize": 32,
"maxDequeueCount": 5
}
},
"functionTimeout": "00:15:00"
}

View File

@@ -1,43 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import azure.functions as func
from onefuzztypes.responses import Info
from ..onefuzzlib.azure.creds import (
get_base_region,
get_base_resource_group,
get_insights_appid,
get_insights_instrumentation_key,
get_instance_id,
get_subscription,
)
from ..onefuzzlib.endpoint_authorization import call_if_user
from ..onefuzzlib.request import ok
from ..onefuzzlib.versions import versions
def get(req: func.HttpRequest) -> func.HttpResponse:
response = ok(
Info(
resource_group=get_base_resource_group(),
region=get_base_region(),
subscription=get_subscription(),
versions=versions(),
instance_id=get_instance_id(),
insights_appid=get_insights_appid(),
insights_instrumentation_key=get_insights_instrumentation_key(),
)
)
return response
def main(req: func.HttpRequest) -> func.HttpResponse:
methods = {"GET": get}
method = methods[req.method]
result = call_if_user(req, method)
return result

View File

@@ -1,19 +0,0 @@
{
"scriptFile": "__init__.py",
"bindings": [
{
"authLevel": "anonymous",
"type": "httpTrigger",
"direction": "in",
"name": "req",
"methods": [
"get"
]
},
{
"type": "http",
"direction": "out",
"name": "$return"
}
]
}

View File

@@ -1,77 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import azure.functions as func
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error
from onefuzztypes.requests import InstanceConfigUpdate
from ..onefuzzlib.azure.nsg import is_onefuzz_nsg, list_nsgs, set_allowed
from ..onefuzzlib.config import InstanceConfig
from ..onefuzzlib.endpoint_authorization import call_if_user, can_modify_config
from ..onefuzzlib.request import not_ok, ok, parse_request
def get(req: func.HttpRequest) -> func.HttpResponse:
return ok(InstanceConfig.fetch())
def post(req: func.HttpRequest) -> func.HttpResponse:
request = parse_request(InstanceConfigUpdate, req)
if isinstance(request, Error):
return not_ok(request, context="instance_config_update")
config = InstanceConfig.fetch()
if not can_modify_config(req, config):
return not_ok(
Error(code=ErrorCode.INVALID_PERMISSION, errors=["unauthorized"]),
context="instance_config_update",
)
update_nsg = False
if request.config.proxy_nsg_config and config.proxy_nsg_config:
request_config = request.config.proxy_nsg_config
current_config = config.proxy_nsg_config
if set(request_config.allowed_service_tags) != set(
current_config.allowed_service_tags
) or set(request_config.allowed_ips) != set(current_config.allowed_ips):
update_nsg = True
config.update(request.config)
config.save()
# Update All NSGs
if update_nsg:
nsgs = list_nsgs()
for nsg in nsgs:
logging.info(
"Checking if nsg: %s (%s) owned by OneFuzz" % (nsg.location, nsg.name)
)
if is_onefuzz_nsg(nsg.location, nsg.name):
result = set_allowed(nsg.location, request.config.proxy_nsg_config)
if isinstance(result, Error):
return not_ok(
Error(
code=ErrorCode.UNABLE_TO_CREATE,
errors=[
"Unable to update nsg %s due to %s"
% (nsg.location, result)
],
),
context="instance_config_update",
)
return ok(config)
def main(req: func.HttpRequest) -> func.HttpResponse:
methods = {"GET": get, "POST": post}
method = methods[req.method]
result = call_if_user(req, method)
return result

View File

@@ -1,20 +0,0 @@
{
"scriptFile": "__init__.py",
"bindings": [
{
"authLevel": "anonymous",
"type": "httpTrigger",
"direction": "in",
"name": "req",
"methods": [
"get",
"post"
]
},
{
"type": "http",
"direction": "out",
"name": "$return"
}
]
}

View File

@@ -1,42 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import azure.functions as func
from onefuzztypes.job_templates import JobTemplateRequest
from onefuzztypes.models import Error
from ..onefuzzlib.endpoint_authorization import call_if_user
from ..onefuzzlib.job_templates.templates import JobTemplateIndex
from ..onefuzzlib.request import not_ok, ok, parse_request
from ..onefuzzlib.user_credentials import parse_jwt_token
def get(req: func.HttpRequest) -> func.HttpResponse:
configs = JobTemplateIndex.get_configs()
return ok(configs)
def post(req: func.HttpRequest) -> func.HttpResponse:
request = parse_request(JobTemplateRequest, req)
if isinstance(request, Error):
return not_ok(request, context="JobTemplateRequest")
user_info = parse_jwt_token(req)
if isinstance(user_info, Error):
return not_ok(user_info, context="JobTemplateRequest")
job = JobTemplateIndex.execute(request, user_info)
if isinstance(job, Error):
return not_ok(job, context="JobTemplateRequest")
return ok(job)
def main(req: func.HttpRequest) -> func.HttpResponse:
methods = {"GET": get, "POST": post}
method = methods[req.method]
result = call_if_user(req, method)
return result

View File

@@ -1,20 +0,0 @@
{
"scriptFile": "__init__.py",
"bindings": [
{
"authLevel": "anonymous",
"type": "httpTrigger",
"direction": "in",
"name": "req",
"methods": [
"get",
"post"
]
},
{
"type": "http",
"direction": "out",
"name": "$return"
}
]
}

View File

@@ -1,69 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import azure.functions as func
from onefuzztypes.enums import ErrorCode
from onefuzztypes.job_templates import (
JobTemplateDelete,
JobTemplateGet,
JobTemplateUpload,
)
from onefuzztypes.models import Error
from onefuzztypes.responses import BoolResult
from ..onefuzzlib.endpoint_authorization import call_if_user
from ..onefuzzlib.job_templates.templates import JobTemplateIndex
from ..onefuzzlib.request import not_ok, ok, parse_request
def get(req: func.HttpRequest) -> func.HttpResponse:
request = parse_request(JobTemplateGet, req)
if isinstance(request, Error):
return not_ok(request, context="JobTemplateGet")
if request.name:
entry = JobTemplateIndex.get_base_entry(request.name)
if entry is None:
return not_ok(
Error(code=ErrorCode.INVALID_REQUEST, errors=["no such job template"]),
context="JobTemplateGet",
)
return ok(entry.template)
templates = JobTemplateIndex.get_index()
return ok(templates)
def post(req: func.HttpRequest) -> func.HttpResponse:
request = parse_request(JobTemplateUpload, req)
if isinstance(request, Error):
return not_ok(request, context="JobTemplateUpload")
entry = JobTemplateIndex(name=request.name, template=request.template)
result = entry.save()
if isinstance(result, Error):
return not_ok(result, context="JobTemplateUpload")
return ok(BoolResult(result=True))
def delete(req: func.HttpRequest) -> func.HttpResponse:
request = parse_request(JobTemplateDelete, req)
if isinstance(request, Error):
return not_ok(request, context="JobTemplateDelete")
entry = JobTemplateIndex.get(request.name)
if entry is not None:
entry.delete()
return ok(BoolResult(result=entry is not None))
def main(req: func.HttpRequest) -> func.HttpResponse:
methods = {"GET": get, "POST": post, "DELETE": delete}
method = methods[req.method]
result = call_if_user(req, method)
return result

View File

@@ -1,22 +0,0 @@
{
"scriptFile": "__init__.py",
"bindings": [
{
"authLevel": "anonymous",
"type": "httpTrigger",
"direction": "in",
"name": "req",
"methods": [
"get",
"post",
"delete"
],
"route": "job_templates/manage"
},
{
"type": "http",
"direction": "out",
"name": "$return"
}
]
}

View File

@@ -1,109 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from uuid import uuid4
import azure.functions as func
from onefuzztypes.enums import ContainerType, ErrorCode, JobState
from onefuzztypes.models import Error, JobConfig, JobTaskInfo
from onefuzztypes.primitives import Container
from onefuzztypes.requests import JobGet, JobSearch
from ..onefuzzlib.azure.containers import create_container
from ..onefuzzlib.azure.storage import StorageType
from ..onefuzzlib.endpoint_authorization import call_if_user
from ..onefuzzlib.jobs import Job
from ..onefuzzlib.request import not_ok, ok, parse_request
from ..onefuzzlib.tasks.main import Task
from ..onefuzzlib.user_credentials import parse_jwt_token
def get(req: func.HttpRequest) -> func.HttpResponse:
request = parse_request(JobSearch, req)
if isinstance(request, Error):
return not_ok(request, context="jobs")
if request.job_id:
job = Job.get(request.job_id)
if not job:
return not_ok(
Error(code=ErrorCode.INVALID_JOB, errors=["no such job"]),
context=request.job_id,
)
task_info = []
for task in Task.search_states(job_id=request.job_id):
task_info.append(
JobTaskInfo(
task_id=task.task_id, type=task.config.task.type, state=task.state
)
)
job.task_info = task_info
return ok(job)
jobs = Job.search_states(states=request.state)
return ok(jobs)
def post(req: func.HttpRequest) -> func.HttpResponse:
request = parse_request(JobConfig, req)
if isinstance(request, Error):
return not_ok(request, context="jobs create")
user_info = parse_jwt_token(req)
if isinstance(user_info, Error):
return not_ok(user_info, context="jobs create")
job = Job(job_id=uuid4(), config=request, user_info=user_info)
# create the job logs container
log_container_sas = create_container(
Container(f"logs-{job.job_id}"),
StorageType.corpus,
metadata={"container_type": ContainerType.logs.name},
)
if not log_container_sas:
return not_ok(
Error(
code=ErrorCode.UNABLE_TO_CREATE_CONTAINER,
errors=["unable to create logs container"],
),
context="logs",
)
sep_index = log_container_sas.find("?")
if sep_index > 0:
log_container = log_container_sas[:sep_index]
else:
log_container = log_container_sas
job.config.logs = log_container
job.save()
return ok(job)
def delete(req: func.HttpRequest) -> func.HttpResponse:
request = parse_request(JobGet, req)
if isinstance(request, Error):
return not_ok(request, context="jobs delete")
job = Job.get(request.job_id)
if not job:
return not_ok(
Error(code=ErrorCode.INVALID_JOB, errors=["no such job"]),
context=request.job_id,
)
if job.state not in [JobState.stopped, JobState.stopping]:
job.state = JobState.stopping
job.save()
return ok(job)
def main(req: func.HttpRequest) -> func.HttpResponse:
methods = {"GET": get, "POST": post, "DELETE": delete}
method = methods[req.method]
result = call_if_user(req, method)
return result

View File

@@ -1,21 +0,0 @@
{
"scriptFile": "__init__.py",
"bindings": [
{
"authLevel": "anonymous",
"type": "httpTrigger",
"direction": "in",
"name": "req",
"methods": [
"get",
"post",
"delete"
]
},
{
"type": "http",
"direction": "out",
"name": "$return"
}
]
}

View File

@@ -1,34 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import azure.functions as func
from ..onefuzzlib.endpoint_authorization import call_if_user
# This endpoint handles the SignalR negotiation
# As we do not differentiate between clients at this time, we pass the Functions runtime
# provided connection straight to the client
#
# For more info:
# https://docs.microsoft.com/en-us/azure/azure-signalr/signalr-concept-internals
def main(req: func.HttpRequest, connectionInfoJson: str) -> func.HttpResponse:
# NOTE: this is a sub-method because the call_if* do not support callbacks with
# additional arguments at this time. Once call_if* supports additional arguments,
# this should be made a generic function
def post(req: func.HttpRequest) -> func.HttpResponse:
return func.HttpResponse(
connectionInfoJson,
status_code=200,
headers={"Content-type": "application/json"},
)
methods = {"POST": post}
method = methods[req.method]
result = call_if_user(req, method)
return result

View File

@@ -1,25 +0,0 @@
{
"scriptFile": "__init__.py",
"bindings": [
{
"authLevel": "anonymous",
"type": "httpTrigger",
"direction": "in",
"name": "req",
"methods": [
"post"
]
},
{
"type": "http",
"direction": "out",
"name": "$return"
},
{
"type": "signalRConnectionInfo",
"direction": "in",
"name": "connectionInfoJson",
"hubName": "dashboard"
}
]
}

View File

@@ -1,122 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import azure.functions as func
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error
from onefuzztypes.requests import NodeGet, NodeSearch, NodeUpdate
from onefuzztypes.responses import BoolResult
from ..onefuzzlib.endpoint_authorization import call_if_user, check_require_admins
from ..onefuzzlib.request import not_ok, ok, parse_request
from ..onefuzzlib.workers.nodes import Node, NodeMessage, NodeTasks
def get(req: func.HttpRequest) -> func.HttpResponse:
request = parse_request(NodeSearch, req)
if isinstance(request, Error):
return not_ok(request, context="pool get")
if request.machine_id:
node = Node.get_by_machine_id(request.machine_id)
if not node:
return not_ok(
Error(code=ErrorCode.UNABLE_TO_FIND, errors=["unable to find node"]),
context=request.machine_id,
)
if isinstance(node, Error):
return not_ok(node, context=request.machine_id)
node.tasks = [n for n in NodeTasks.get_by_machine_id(request.machine_id)]
node.messages = [
x.message for x in NodeMessage.get_messages(request.machine_id)
]
return ok(node)
nodes = Node.search_states(
states=request.state,
pool_name=request.pool_name,
scaleset_id=request.scaleset_id,
)
return ok(nodes)
def post(req: func.HttpRequest) -> func.HttpResponse:
request = parse_request(NodeUpdate, req)
if isinstance(request, Error):
return not_ok(request, context="NodeUpdate")
answer = check_require_admins(req)
if isinstance(answer, Error):
return not_ok(answer, context="NodeUpdate")
node = Node.get_by_machine_id(request.machine_id)
if not node:
return not_ok(
Error(code=ErrorCode.UNABLE_TO_FIND, errors=["unable to find node"]),
context=request.machine_id,
)
if request.debug_keep_node is not None:
node.debug_keep_node = request.debug_keep_node
node.save()
return ok(BoolResult(result=True))
def delete(req: func.HttpRequest) -> func.HttpResponse:
request = parse_request(NodeGet, req)
if isinstance(request, Error):
return not_ok(request, context="NodeDelete")
answer = check_require_admins(req)
if isinstance(answer, Error):
return not_ok(answer, context="NodeDelete")
node = Node.get_by_machine_id(request.machine_id)
if not node:
return not_ok(
Error(code=ErrorCode.UNABLE_TO_FIND, errors=["unable to find node"]),
context=request.machine_id,
)
node.set_halt()
if node.debug_keep_node:
node.debug_keep_node = False
node.save()
return ok(BoolResult(result=True))
def patch(req: func.HttpRequest) -> func.HttpResponse:
request = parse_request(NodeGet, req)
if isinstance(request, Error):
return not_ok(request, context="NodeReimage")
answer = check_require_admins(req)
if isinstance(answer, Error):
return not_ok(answer, context="NodeReimage")
node = Node.get_by_machine_id(request.machine_id)
if not node:
return not_ok(
Error(code=ErrorCode.UNABLE_TO_FIND, errors=["unable to find node"]),
context=request.machine_id,
)
node.stop(done=True)
if node.debug_keep_node:
node.debug_keep_node = False
node.save()
return ok(BoolResult(result=True))
def main(req: func.HttpRequest) -> func.HttpResponse:
methods = {"GET": get, "PATCH": patch, "DELETE": delete, "POST": post}
method = methods[req.method]
result = call_if_user(req, method)
return result

View File

@@ -1,22 +0,0 @@
{
"scriptFile": "__init__.py",
"bindings": [
{
"authLevel": "anonymous",
"type": "httpTrigger",
"direction": "in",
"name": "req",
"methods": [
"get",
"patch",
"delete",
"post"
]
},
{
"type": "http",
"direction": "out",
"name": "$return"
}
]
}

View File

@@ -1,40 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import azure.functions as func
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error
from onefuzztypes.requests import NodeAddSshKey
from onefuzztypes.responses import BoolResult
from ..onefuzzlib.endpoint_authorization import call_if_user
from ..onefuzzlib.request import not_ok, ok, parse_request
from ..onefuzzlib.workers.nodes import Node
def post(req: func.HttpRequest) -> func.HttpResponse:
request = parse_request(NodeAddSshKey, req)
if isinstance(request, Error):
return not_ok(request, context="NodeAddSshKey")
node = Node.get_by_machine_id(request.machine_id)
if not node:
return not_ok(
Error(code=ErrorCode.UNABLE_TO_FIND, errors=["unable to find node"]),
context=request.machine_id,
)
result = node.add_ssh_public_key(public_key=request.public_key)
if isinstance(result, Error):
return not_ok(result, context="NodeAddSshKey")
return ok(BoolResult(result=True))
def main(req: func.HttpRequest) -> func.HttpResponse:
methods = {"POST": post}
method = methods[req.method]
result = call_if_user(req, method)
return result

View File

@@ -1,20 +0,0 @@
{
"scriptFile": "__init__.py",
"bindings": [
{
"authLevel": "anonymous",
"type": "httpTrigger",
"direction": "in",
"name": "req",
"methods": [
"post"
],
"route": "node/add_ssh_key"
},
{
"type": "http",
"direction": "out",
"name": "$return"
}
]
}

View File

@@ -1,70 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import azure.functions as func
from onefuzztypes.models import Error
from onefuzztypes.requests import (
NotificationCreate,
NotificationGet,
NotificationSearch,
)
from ..onefuzzlib.endpoint_authorization import call_if_user
from ..onefuzzlib.notifications.main import Notification
from ..onefuzzlib.request import not_ok, ok, parse_request
def get(req: func.HttpRequest) -> func.HttpResponse:
logging.info("notification search")
request = parse_request(NotificationSearch, req)
if isinstance(request, Error):
return not_ok(request, context="notification search")
if request.container:
entries = Notification.search(query={"container": request.container})
else:
entries = Notification.search()
return ok(entries)
def post(req: func.HttpRequest) -> func.HttpResponse:
logging.info("adding notification hook")
request = parse_request(NotificationCreate, req)
if isinstance(request, Error):
return not_ok(request, context="notification create")
entry = Notification.create(
container=request.container,
config=request.config,
replace_existing=request.replace_existing,
)
if isinstance(entry, Error):
return not_ok(entry, context="notification create")
return ok(entry)
def delete(req: func.HttpRequest) -> func.HttpResponse:
request = parse_request(NotificationGet, req)
if isinstance(request, Error):
return not_ok(request, context="notification delete")
entry = Notification.get_by_id(request.notification_id)
if isinstance(entry, Error):
return not_ok(entry, context="notification delete")
entry.delete()
return ok(entry)
def main(req: func.HttpRequest) -> func.HttpResponse:
methods = {"GET": get, "POST": post, "DELETE": delete}
method = methods[req.method]
result = call_if_user(req, method)
return result

View File

@@ -1,21 +0,0 @@
{
"scriptFile": "__init__.py",
"bindings": [
{
"authLevel": "anonymous",
"type": "httpTrigger",
"direction": "in",
"name": "req",
"methods": [
"get",
"post",
"delete"
]
},
{
"type": "http",
"direction": "out",
"name": "$return"
}
]
}

View File

@@ -1,5 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: disable=W0612,C0111
__version__ = "0.0.0"

View File

@@ -1,271 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import datetime
import logging
from typing import Optional, cast
from uuid import UUID
from onefuzztypes.enums import (
ErrorCode,
NodeState,
NodeTaskState,
TaskDebugFlag,
TaskState,
)
from onefuzztypes.models import (
Error,
NodeDoneEventData,
NodeSettingUpEventData,
NodeStateUpdate,
Result,
WorkerDoneEvent,
WorkerEvent,
WorkerRunningEvent,
)
from .task_event import TaskEvent
from .tasks.main import Task
from .workers.nodes import Node, NodeTasks
MAX_OUTPUT_SIZE = 4096
def get_node(machine_id: UUID) -> Result[Node]:
node = Node.get_by_machine_id(machine_id)
if not node:
return Error(code=ErrorCode.INVALID_NODE, errors=["unable to find node"])
return node
def on_state_update(
machine_id: UUID,
state_update: NodeStateUpdate,
) -> Result[None]:
state = state_update.state
node = get_node(machine_id)
if isinstance(node, Error):
if state == NodeState.done:
logging.warning(
"unable to process state update event. machine_id:"
f"{machine_id} state event:{state_update} error:{node}"
)
return None
return node
if state == NodeState.free:
if node.reimage_requested or node.delete_requested:
logging.info("stopping free node with reset flags: %s", node.machine_id)
node.stop()
return None
if node.could_shrink_scaleset():
logging.info("stopping free node to resize scaleset: %s", node.machine_id)
node.set_halt()
return None
if state == NodeState.init:
if node.delete_requested:
logging.info("stopping node (init and delete_requested): %s", machine_id)
node.stop()
return None
# not checking reimage_requested, as nodes only send 'init' state once. If
# they send 'init' with reimage_requested, it's because the node was reimaged
# successfully.
node.reimage_requested = False
node.initialized_at = datetime.datetime.now(datetime.timezone.utc)
node.set_state(state)
return None
logging.info("node state update: %s from:%s to:%s", machine_id, node.state, state)
node.set_state(state)
if state == NodeState.free:
logging.info("node now available for work: %s", machine_id)
elif state == NodeState.setting_up:
# Model-validated.
#
# This field will be required in the future.
# For now, it is optional for back compat.
setting_up_data = cast(
Optional[NodeSettingUpEventData],
state_update.data,
)
if setting_up_data:
if not setting_up_data.tasks:
return Error(
code=ErrorCode.INVALID_REQUEST,
errors=["setup without tasks. machine_id: %s", str(machine_id)],
)
for task_id in setting_up_data.tasks:
task = Task.get_by_task_id(task_id)
if isinstance(task, Error):
return task
logging.info(
"node starting task. machine_id: %s job_id: %s task_id: %s",
machine_id,
task.job_id,
task.task_id,
)
# The task state may be `running` if it has `vm_count` > 1, and
# another node is concurrently executing the task. If so, leave
# the state as-is, to represent the max progress made.
#
# Other states we would want to preserve are excluded by the
# outermost conditional check.
if task.state not in [TaskState.running, TaskState.setting_up]:
task.set_state(TaskState.setting_up)
# Note: we set the node task state to `setting_up`, even though
# the task itself may be `running`.
node_task = NodeTasks(
machine_id=machine_id,
task_id=task_id,
state=NodeTaskState.setting_up,
)
node_task.save()
elif state == NodeState.done:
# Model-validated.
#
# This field will be required in the future.
# For now, it is optional for back compat.
done_data = cast(Optional[NodeDoneEventData], state_update.data)
error = None
if done_data:
if done_data.error:
error_text = done_data.json(exclude_none=True)
error = Error(
code=ErrorCode.TASK_FAILED,
errors=[error_text],
)
logging.error(
"node 'done' with error: machine_id:%s, data:%s",
machine_id,
error_text,
)
# if tasks are running on the node when it reports as Done
# those are stopped early
node.mark_tasks_stopped_early(error=error)
node.to_reimage(done=True)
return None
def on_worker_event_running(
machine_id: UUID, event: WorkerRunningEvent
) -> Result[None]:
task = Task.get_by_task_id(event.task_id)
if isinstance(task, Error):
return task
node = get_node(machine_id)
if isinstance(node, Error):
return node
if node.state not in NodeState.ready_for_reset():
node.set_state(NodeState.busy)
node_task = NodeTasks(
machine_id=machine_id, task_id=event.task_id, state=NodeTaskState.running
)
node_task.save()
if task.state in TaskState.shutting_down():
logging.info(
"ignoring task start from node. "
"machine_id:%s job_id:%s task_id:%s (state: %s)",
machine_id,
task.job_id,
task.task_id,
task.state,
)
return None
logging.info(
"task started on node. machine_id:%s job_id%s task_id:%s",
machine_id,
task.job_id,
task.task_id,
)
task.set_state(TaskState.running)
task_event = TaskEvent(
task_id=task.task_id,
machine_id=machine_id,
event_data=WorkerEvent(running=event),
)
task_event.save()
return None
def on_worker_event_done(machine_id: UUID, event: WorkerDoneEvent) -> Result[None]:
task = Task.get_by_task_id(event.task_id)
if isinstance(task, Error):
return task
node = get_node(machine_id)
if isinstance(node, Error):
return node
if event.exit_status.success:
logging.info(
"task done. %s:%s status:%s", task.job_id, task.task_id, event.exit_status
)
task.mark_stopping()
if (
task.config.debug
and TaskDebugFlag.keep_node_on_completion in task.config.debug
):
node.debug_keep_node = True
node.save()
else:
task.mark_failed(
Error(
code=ErrorCode.TASK_FAILED,
errors=[
"task failed. exit_status:%s" % event.exit_status,
event.stdout[-MAX_OUTPUT_SIZE:],
event.stderr[-MAX_OUTPUT_SIZE:],
],
)
)
if task.config.debug and (
TaskDebugFlag.keep_node_on_failure in task.config.debug
or TaskDebugFlag.keep_node_on_completion in task.config.debug
):
node.debug_keep_node = True
node.save()
if not node.debug_keep_node:
node_task = NodeTasks.get(machine_id, event.task_id)
if node_task:
node_task.delete()
event.stdout = event.stdout[-MAX_OUTPUT_SIZE:]
event.stderr = event.stderr[-MAX_OUTPUT_SIZE:]
task_event = TaskEvent(
task_id=task.task_id, machine_id=machine_id, event_data=WorkerEvent(done=event)
)
task_event.save()
return None
def on_worker_event(machine_id: UUID, event: WorkerEvent) -> Result[None]:
if event.running:
return on_worker_event_running(machine_id, event.running)
elif event.done:
return on_worker_event_done(machine_id, event.done)
else:
raise NotImplementedError

View File

@@ -1,171 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import math
from typing import List
from onefuzztypes.enums import NodeState, ScalesetState
from onefuzztypes.models import AutoScaleConfig, TaskPool
from .tasks.main import Task
from .workers.nodes import Node
from .workers.pools import Pool
from .workers.scalesets import Scaleset
def scale_up(pool: Pool, scalesets: List[Scaleset], nodes_needed: int) -> None:
logging.info("Scaling up")
autoscale_config = pool.autoscale
if not isinstance(autoscale_config, AutoScaleConfig):
return
for scaleset in scalesets:
if scaleset.state in [ScalesetState.running, ScalesetState.resize]:
max_size = min(scaleset.max_size(), autoscale_config.scaleset_size)
logging.info(
"scaleset:%s size:%d max_size:%d",
scaleset.scaleset_id,
scaleset.size,
max_size,
)
if scaleset.size < max_size:
current_size = scaleset.size
if nodes_needed <= max_size - current_size:
scaleset.set_size(current_size + nodes_needed)
nodes_needed = 0
else:
scaleset.set_size(max_size)
nodes_needed = nodes_needed - (max_size - current_size)
else:
continue
if nodes_needed == 0:
return
for _ in range(
math.ceil(
nodes_needed
/ min(
Scaleset.scaleset_max_size(autoscale_config.image),
autoscale_config.scaleset_size,
)
)
):
logging.info("Creating Scaleset for Pool %s", pool.name)
max_nodes_scaleset = min(
Scaleset.scaleset_max_size(autoscale_config.image),
autoscale_config.scaleset_size,
nodes_needed,
)
if not autoscale_config.region:
raise Exception("Region is missing")
Scaleset.create(
pool_name=pool.name,
vm_sku=autoscale_config.vm_sku,
image=autoscale_config.image,
region=autoscale_config.region,
size=max_nodes_scaleset,
spot_instances=autoscale_config.spot_instances,
ephemeral_os_disks=autoscale_config.ephemeral_os_disks,
tags={"pool": pool.name},
)
nodes_needed -= max_nodes_scaleset
def scale_down(scalesets: List[Scaleset], nodes_to_remove: int) -> None:
logging.info("Scaling down")
for scaleset in scalesets:
num_of_nodes = len(Node.search_states(scaleset_id=scaleset.scaleset_id))
if scaleset.size != num_of_nodes and scaleset.state not in [
ScalesetState.resize,
ScalesetState.shutdown,
ScalesetState.halt,
]:
scaleset.set_state(ScalesetState.resize)
free_nodes = Node.search_states(
scaleset_id=scaleset.scaleset_id,
states=[NodeState.free],
)
nodes = []
for node in free_nodes:
if not node.delete_requested:
nodes.append(node)
logging.info("Scaleset: %s, #Free Nodes: %s", scaleset.scaleset_id, len(nodes))
if nodes and nodes_to_remove > 0:
max_nodes_remove = min(len(nodes), nodes_to_remove)
# All nodes in scaleset are free. Can shutdown VMSS
if max_nodes_remove >= scaleset.size and len(nodes) >= scaleset.size:
scaleset.set_state(ScalesetState.shutdown)
nodes_to_remove = nodes_to_remove - scaleset.size
for node in nodes:
node.set_shutdown()
continue
# Resize of VMSS needed
scaleset.set_size(scaleset.size - max_nodes_remove)
nodes_to_remove = nodes_to_remove - max_nodes_remove
scaleset.set_state(ScalesetState.resize)
def get_vm_count(tasks: List[Task]) -> int:
count = 0
for task in tasks:
task_pool = task.get_pool()
if (
not task_pool
or not isinstance(task_pool, Pool)
or not isinstance(task.config.pool, TaskPool)
):
continue
count += task.config.pool.count
return count
def autoscale_pool(pool: Pool) -> None:
logging.info("autoscale: %s", pool.autoscale)
if not pool.autoscale:
return
# get all the tasks (count not stopped) for the pool
tasks = Task.get_tasks_by_pool_name(pool.name)
logging.info("Pool: %s, #Tasks %d", pool.name, len(tasks))
num_of_tasks = get_vm_count(tasks)
nodes_needed = max(num_of_tasks, pool.autoscale.min_size)
if pool.autoscale.max_size:
nodes_needed = min(nodes_needed, pool.autoscale.max_size)
# match the scaleset logic with the pool
# get all the scalesets for the pool
scalesets = Scaleset.search_by_pool(pool.name)
pool_resize = False
for scaleset in scalesets:
if scaleset.state in ScalesetState.modifying():
pool_resize = True
break
nodes_needed = nodes_needed - scaleset.size
if pool_resize:
return
logging.info("Pool: %s, #Nodes Needed: %d", pool.name, nodes_needed)
if nodes_needed > 0:
# resizing scaleset or creating new scaleset.
scale_up(pool, scalesets, nodes_needed)
elif nodes_needed < 0:
for scaleset in scalesets:
nodes = Node.search_states(scaleset_id=scaleset.scaleset_id)
for node in nodes:
if node.delete_requested:
nodes_needed += 1
if nodes_needed < 0:
scale_down(scalesets, abs(nodes_needed))

View File

@@ -1,37 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import subprocess # nosec - used for ssh key generation
import tempfile
from typing import Tuple
from uuid import uuid4
from onefuzztypes.models import Authentication
def generate_keypair() -> Tuple[str, str]:
with tempfile.TemporaryDirectory() as tmpdir:
filename = os.path.join(tmpdir, "key")
cmd = ["ssh-keygen", "-t", "rsa", "-f", filename, "-P", "", "-b", "2048"]
subprocess.check_output(cmd) # nosec - all arguments are under our control
with open(filename, "r") as handle:
private = handle.read()
with open(filename + ".pub", "r") as handle:
public = handle.read().strip()
return (public, private)
def build_auth() -> Authentication:
public_key, private_key = generate_keypair()
auth = Authentication(
password=str(uuid4()), public_key=public_key, private_key=private_key
)
return auth

View File

@@ -1,328 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import uuid
from datetime import timedelta
from typing import Any, Dict, Optional, Union
from uuid import UUID
from azure.core.exceptions import ResourceNotFoundError
from azure.mgmt.monitor.models import (
AutoscaleProfile,
AutoscaleSettingResource,
ComparisonOperationType,
DiagnosticSettingsResource,
LogSettings,
MetricStatisticType,
MetricTrigger,
RetentionPolicy,
ScaleAction,
ScaleCapacity,
ScaleDirection,
ScaleRule,
ScaleType,
TimeAggregationType,
)
from msrestazure.azure_exceptions import CloudError
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error
from onefuzztypes.primitives import Region
from .creds import (
get_base_region,
get_base_resource_group,
get_subscription,
retry_on_auth_failure,
)
from .log_analytics import get_workspace_id
from .monitor import get_monitor_client
@retry_on_auth_failure()
def get_auto_scale_settings(
vmss: UUID,
) -> Union[Optional[AutoscaleSettingResource], Error]:
logging.info("Getting auto scale settings for %s" % vmss)
client = get_monitor_client()
resource_group = get_base_resource_group()
try:
auto_scale_collections = client.autoscale_settings.list_by_resource_group(
resource_group
)
for auto_scale in auto_scale_collections:
if str(auto_scale.target_resource_uri).endswith(str(vmss)):
logging.info("Found auto scale settings for %s" % vmss)
return auto_scale
except (ResourceNotFoundError, CloudError):
return Error(
code=ErrorCode.INVALID_CONFIGURATION,
errors=[
"Failed to check if scaleset %s already has an autoscale resource"
% vmss
],
)
return None
@retry_on_auth_failure()
def add_auto_scale_to_vmss(
vmss: UUID, auto_scale_profile: AutoscaleProfile
) -> Optional[Error]:
logging.info("Checking scaleset %s for existing auto scale resources" % vmss)
existing_auto_scale_resource = get_auto_scale_settings(vmss)
if isinstance(existing_auto_scale_resource, Error):
return existing_auto_scale_resource
if existing_auto_scale_resource is not None:
logging.warning("Scaleset %s already has auto scale resource" % vmss)
return None
auto_scale_resource = create_auto_scale_resource_for(
vmss, get_base_region(), auto_scale_profile
)
if isinstance(auto_scale_resource, Error):
return auto_scale_resource
diagnostics_resource = setup_auto_scale_diagnostics(
auto_scale_resource.id, auto_scale_resource.name, get_workspace_id()
)
if isinstance(diagnostics_resource, Error):
return diagnostics_resource
return None
def update_auto_scale(auto_scale_resource: AutoscaleSettingResource) -> Optional[Error]:
logging.info("Updating auto scale resource: %s" % auto_scale_resource.name)
client = get_monitor_client()
resource_group = get_base_resource_group()
try:
auto_scale_resource = client.autoscale_settings.create_or_update(
resource_group, auto_scale_resource.name, auto_scale_resource
)
logging.info(
"Successfully updated auto scale resource: %s" % auto_scale_resource.name
)
except (ResourceNotFoundError, CloudError):
return Error(
code=ErrorCode.UNABLE_TO_UPDATE,
errors=[
"unable to update auto scale resource with name: %s and profile: %s"
% (auto_scale_resource.name, auto_scale_resource)
],
)
return None
def create_auto_scale_resource_for(
resource_id: UUID, location: Region, profile: AutoscaleProfile
) -> Union[AutoscaleSettingResource, Error]:
logging.info("Creating auto scale resource for: %s" % resource_id)
client = get_monitor_client()
resource_group = get_base_resource_group()
subscription = get_subscription()
scaleset_uri = (
"/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/virtualMachineScaleSets/%s" # noqa: E501
% (subscription, resource_group, resource_id)
)
params: Dict[str, Any] = {
"location": location,
"profiles": [profile],
"target_resource_uri": scaleset_uri,
"enabled": True,
}
try:
auto_scale_resource = client.autoscale_settings.create_or_update(
resource_group, str(uuid.uuid4()), params
)
logging.info(
"Successfully created auto scale resource %s for %s"
% (auto_scale_resource.id, resource_id)
)
return auto_scale_resource
except (ResourceNotFoundError, CloudError):
return Error(
code=ErrorCode.UNABLE_TO_CREATE,
errors=[
"unable to create auto scale resource for resource: %s with profile: %s"
% (resource_id, profile)
],
)
def create_auto_scale_profile(
queue_uri: str,
min: int,
max: int,
default: int,
scale_out_amount: int,
scale_out_cooldown: int,
scale_in_amount: int,
scale_in_cooldown: int,
) -> AutoscaleProfile:
return AutoscaleProfile(
name=str(uuid.uuid4()),
capacity=ScaleCapacity(minimum=min, maximum=max, default=max),
# Auto scale tuning guidance:
# https://docs.microsoft.com/en-us/azure/architecture/best-practices/auto-scaling
rules=[
ScaleRule(
metric_trigger=MetricTrigger(
metric_name="ApproximateMessageCount",
metric_resource_uri=queue_uri,
# Check every 15 minutes
time_grain=timedelta(minutes=15),
# The average number of messages in the pool queue
time_aggregation=TimeAggregationType.AVERAGE,
statistic=MetricStatisticType.COUNT,
# Over the past 15 minutes
time_window=timedelta(minutes=15),
# When there is at least 1 message in the pool queue
operator=ComparisonOperationType.GREATER_THAN_OR_EQUAL,
threshold=1,
divide_per_instance=False,
),
scale_action=ScaleAction(
direction=ScaleDirection.INCREASE,
type=ScaleType.CHANGE_COUNT,
value=scale_out_amount,
cooldown=timedelta(minutes=scale_out_cooldown),
),
),
# Scale in
ScaleRule(
# Scale in if no work in the past 20 mins
metric_trigger=MetricTrigger(
metric_name="ApproximateMessageCount",
metric_resource_uri=queue_uri,
# Check every 10 minutes
time_grain=timedelta(minutes=10),
# The average number of messages in the pool queue
time_aggregation=TimeAggregationType.AVERAGE,
statistic=MetricStatisticType.SUM,
# Over the past 10 minutes
time_window=timedelta(minutes=10),
# When there are no messages in the pool queue
operator=ComparisonOperationType.EQUALS,
threshold=0,
divide_per_instance=False,
),
scale_action=ScaleAction(
direction=ScaleDirection.DECREASE,
type=ScaleType.CHANGE_COUNT,
value=scale_in_amount,
cooldown=timedelta(minutes=scale_in_cooldown),
),
),
],
)
def get_auto_scale_profile(scaleset_id: UUID) -> AutoscaleProfile:
logging.info("Getting scaleset %s for existing auto scale resources" % scaleset_id)
client = get_monitor_client()
resource_group = get_base_resource_group()
auto_scale_resource = None
try:
auto_scale_collections = client.autoscale_settings.list_by_resource_group(
resource_group
)
for auto_scale in auto_scale_collections:
if str(auto_scale.target_resource_uri).endswith(str(scaleset_id)):
auto_scale_resource = auto_scale
auto_scale_profiles = auto_scale_resource.profiles
if len(auto_scale_profiles) != 1:
logging.info(
"Found more than one autoscaling profile for scaleset %s"
% scaleset_id
)
return auto_scale_profiles[0]
except (ResourceNotFoundError, CloudError):
return Error(
code=ErrorCode.INVALID_CONFIGURATION,
errors=["Failed to query scaleset %s autoscale resource" % scaleset_id],
)
def default_auto_scale_profile(queue_uri: str, scaleset_size: int) -> AutoscaleProfile:
return create_auto_scale_profile(
queue_uri, 1, scaleset_size, scaleset_size, 1, 10, 1, 5
)
def shutdown_scaleset_rule(queue_uri: str) -> ScaleRule:
return ScaleRule(
# Scale in if there are 0 or more messages in the queue (aka: every time)
metric_trigger=MetricTrigger(
metric_name="ApproximateMessageCount",
metric_resource_uri=queue_uri,
# Check every 5 minutes
time_grain=timedelta(minutes=5),
# The average number of messages in the pool queue
time_aggregation=TimeAggregationType.AVERAGE,
statistic=MetricStatisticType.SUM,
# Over the past 5 minutes
time_window=timedelta(minutes=5),
operator=ComparisonOperationType.GREATER_THAN_OR_EQUAL,
threshold=0,
divide_per_instance=False,
),
scale_action=ScaleAction(
direction=ScaleDirection.DECREASE,
type=ScaleType.CHANGE_COUNT,
value=1,
cooldown=timedelta(minutes=5),
),
)
def setup_auto_scale_diagnostics(
auto_scale_resource_uri: str,
auto_scale_resource_name: str,
log_analytics_workspace_id: str,
) -> Union[DiagnosticSettingsResource, Error]:
logging.info("Setting up diagnostics for auto scale")
client = get_monitor_client()
log_settings = LogSettings(
enabled=True,
category_group="allLogs",
retention_policy=RetentionPolicy(enabled=True, days=30),
)
params: Dict[str, Any] = {
"logs": [log_settings],
"workspace_id": log_analytics_workspace_id,
}
try:
diagnostics = client.diagnostic_settings.create_or_update(
auto_scale_resource_uri, "%s-diagnostics" % auto_scale_resource_name, params
)
logging.info(
"Diagnostics created for auto scale resource: %s" % auto_scale_resource_uri
)
return diagnostics
except (ResourceNotFoundError, CloudError):
return Error(
code=ErrorCode.UNABLE_TO_CREATE,
errors=[
"unable to setup diagnostics for auto scale resource: %s"
% (auto_scale_resource_uri)
],
)

View File

@@ -1,9 +0,0 @@
from azure.mgmt.compute import ComputeManagementClient
from memoization import cached
from .creds import get_identity, get_subscription
@cached
def get_compute_client() -> ComputeManagementClient:
return ComputeManagementClient(get_identity(), get_subscription())

View File

@ -1,388 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import datetime
import logging
import os
import urllib.parse
from typing import Dict, Optional, Tuple, Union, cast
from azure.common import AzureHttpError, AzureMissingResourceHttpError
from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
from azure.storage.blob import (
BlobClient,
BlobSasPermissions,
BlobServiceClient,
ContainerClient,
ContainerSasPermissions,
generate_blob_sas,
generate_container_sas,
)
from memoization import cached
from onefuzztypes.primitives import Container
from .storage import (
StorageType,
choose_account,
get_accounts,
get_storage_account_name_key,
get_storage_account_name_key_by_name,
)
CONTAINER_SAS_DEFAULT_DURATION = datetime.timedelta(days=30)
def get_url(account_name: str) -> str:
return f"https://{account_name}.blob.core.windows.net/"
@cached
def get_blob_service(account_id: str) -> BlobServiceClient:
logging.debug("getting blob container (account_id: %s)", account_id)
account_name, account_key = get_storage_account_name_key(account_id)
account_url = get_url(account_name)
service = BlobServiceClient(account_url=account_url, credential=account_key)
return service
def container_metadata(
container: Container, account_id: str
) -> Optional[Dict[str, str]]:
try:
result = (
get_blob_service(account_id)
.get_container_client(container)
.get_container_properties()
)
return cast(Dict[str, str], result)
except AzureHttpError:
pass
return None
def find_container(
container: Container, storage_type: StorageType
) -> Optional[ContainerClient]:
accounts = get_accounts(storage_type)
# check secondary accounts first by searching in reverse.
#
# By implementation, the primary account is specified first, followed by
# any secondary accounts.
#
    # Secondary accounts, if they exist, are preferred for containers, as they
    # have increased IOP rates; this should be a slight optimization.
for account in reversed(accounts):
client = get_blob_service(account).get_container_client(container)
if client.exists():
return client
return None
def container_exists(container: Container, storage_type: StorageType) -> bool:
return find_container(container, storage_type) is not None
def get_containers(storage_type: StorageType) -> Dict[str, Dict[str, str]]:
containers: Dict[str, Dict[str, str]] = {}
for account_id in get_accounts(storage_type):
containers.update(
{
x.name: x.metadata
for x in get_blob_service(account_id).list_containers(
include_metadata=True
)
}
)
return containers
def get_container_metadata(
container: Container, storage_type: StorageType
) -> Optional[Dict[str, str]]:
client = find_container(container, storage_type)
if client is None:
return None
result = client.get_container_properties().metadata
return cast(Dict[str, str], result)
def add_container_sas_url(
container_url: str, duration: datetime.timedelta = CONTAINER_SAS_DEFAULT_DURATION
) -> str:
parsed = urllib.parse.urlparse(container_url)
query = urllib.parse.parse_qs(parsed.query)
if "sig" in query:
return container_url
else:
start, expiry = sas_time_window(duration)
account_name = parsed.netloc.split(".")[0]
account_key = get_storage_account_name_key_by_name(account_name)
sas_token = generate_container_sas(
account_name=account_name,
container_name=parsed.path.split("/")[1],
account_key=account_key,
permission=ContainerSasPermissions(
read=True, write=True, delete=True, list=True
),
expiry=expiry,
start=start,
)
return f"{container_url}?{sas_token}"
def get_or_create_container_client(
container: Container,
storage_type: StorageType,
metadata: Optional[Dict[str, str]],
) -> Optional[ContainerClient]:
client = find_container(container, storage_type)
if client is None:
account = choose_account(storage_type)
client = get_blob_service(account).get_container_client(container)
try:
client.create_container(metadata=metadata)
except (ResourceExistsError, AzureHttpError) as err:
# note: resource exists error happens during creation if the container
# is being deleted
logging.error(
(
"unable to create container. account: %s "
"container: %s metadata: %s - %s"
),
account,
container,
metadata,
err,
)
return None
return client
def create_container(
container: Container,
storage_type: StorageType,
metadata: Optional[Dict[str, str]],
) -> Optional[str]:
client = get_or_create_container_client(container, storage_type, metadata)
if client is None:
return None
return get_container_sas_url_service(
client,
read=True,
write=True,
delete=True,
list_=True,
)
def delete_container(container: Container, storage_type: StorageType) -> bool:
accounts = get_accounts(storage_type)
deleted = False
for account in accounts:
service = get_blob_service(account)
try:
service.delete_container(container)
deleted = True
except ResourceNotFoundError:
pass
return deleted
def sas_time_window(
duration: datetime.timedelta,
) -> Tuple[datetime.datetime, datetime.datetime]:
    # SAS URLs are valid 6 hours earlier, primarily to work around dev
    # workstations having out-of-sync time. Additionally, SAS URLs expire 15
    # minutes later than requested, based on the "Be careful with SAS start
    # time" guidance.
# Ref: https://docs.microsoft.com/en-us/azure/storage/common/storage-sas-overview
SAS_START_TIME_DELTA = datetime.timedelta(hours=6)
SAS_END_TIME_DELTA = datetime.timedelta(minutes=15)
now = datetime.datetime.utcnow()
start = now - SAS_START_TIME_DELTA
expiry = now + duration + SAS_END_TIME_DELTA
return (start, expiry)
def get_container_sas_url_service(
client: ContainerClient,
*,
read: bool = False,
write: bool = False,
delete: bool = False,
list_: bool = False,
delete_previous_version: bool = False,
tag: bool = False,
duration: datetime.timedelta = CONTAINER_SAS_DEFAULT_DURATION,
) -> str:
account_name = client.account_name
container_name = client.container_name
account_key = get_storage_account_name_key_by_name(account_name)
start, expiry = sas_time_window(duration)
sas = generate_container_sas(
account_name,
container_name,
account_key=account_key,
permission=ContainerSasPermissions(
read=read,
write=write,
delete=delete,
list=list_,
delete_previous_version=delete_previous_version,
tag=tag,
),
start=start,
expiry=expiry,
)
with_sas = ContainerClient(
get_url(account_name),
container_name=container_name,
credential=sas,
)
return cast(str, with_sas.url)
def get_container_sas_url(
container: Container,
storage_type: StorageType,
*,
read: bool = False,
write: bool = False,
delete: bool = False,
list_: bool = False,
) -> str:
client = find_container(container, storage_type)
if not client:
raise Exception("unable to create container sas for missing container")
return get_container_sas_url_service(
client,
read=read,
write=write,
delete=delete,
list_=list_,
)
def get_file_url(container: Container, name: str, storage_type: StorageType) -> str:
client = find_container(container, storage_type)
if not client:
raise Exception("unable to find container: %s - %s" % (container, storage_type))
# get_url has a trailing '/'
return f"{get_url(client.account_name)}{container}/{name}"
def get_file_sas_url(
container: Container,
name: str,
storage_type: StorageType,
*,
read: bool = False,
add: bool = False,
create: bool = False,
write: bool = False,
delete: bool = False,
delete_previous_version: bool = False,
tag: bool = False,
duration: datetime.timedelta = CONTAINER_SAS_DEFAULT_DURATION,
) -> str:
client = find_container(container, storage_type)
if not client:
raise Exception("unable to find container: %s - %s" % (container, storage_type))
account_key = get_storage_account_name_key_by_name(client.account_name)
start, expiry = sas_time_window(duration)
permission = BlobSasPermissions(
read=read,
add=add,
create=create,
write=write,
delete=delete,
delete_previous_version=delete_previous_version,
tag=tag,
)
sas = generate_blob_sas(
client.account_name,
container,
name,
account_key=account_key,
permission=permission,
expiry=expiry,
start=start,
)
with_sas = BlobClient(
get_url(client.account_name),
container,
name,
credential=sas,
)
return cast(str, with_sas.url)
def save_blob(
container: Container,
name: str,
data: Union[str, bytes],
storage_type: StorageType,
) -> None:
client = find_container(container, storage_type)
if not client:
raise Exception("unable to find container: %s - %s" % (container, storage_type))
client.get_blob_client(name).upload_blob(data, overwrite=True)
def get_blob(
container: Container, name: str, storage_type: StorageType
) -> Optional[bytes]:
client = find_container(container, storage_type)
if not client:
return None
try:
return cast(
bytes, client.get_blob_client(name).download_blob().content_as_bytes()
)
except AzureMissingResourceHttpError:
return None
def blob_exists(container: Container, name: str, storage_type: StorageType) -> bool:
client = find_container(container, storage_type)
if not client:
return False
return cast(bool, client.get_blob_client(name).exists())
def delete_blob(container: Container, name: str, storage_type: StorageType) -> bool:
client = find_container(container, storage_type)
if not client:
return False
try:
client.get_blob_client(name).delete_blob()
return True
except AzureMissingResourceHttpError:
return False
def auth_download_url(container: Container, filename: str) -> str:
instance = os.environ["ONEFUZZ_INSTANCE"]
return "%s/api/download?%s" % (
instance,
urllib.parse.urlencode({"container": container, "filename": filename}),
)
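# Illustrative usage sketch, not part of the original module: write a small
# blob with save_blob and read it back with get_blob. The container and blob
# names are hypothetical.
def example_blob_round_trip() -> Optional[bytes]:
    container = Container("example-container")
    save_blob(container, "settings.json", b'{"enabled": true}', StorageType.config)
    return get_blob(container, "settings.json", StorageType.config)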

View File

@ -1,246 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import functools
import logging
import os
import urllib.parse
from typing import Any, Callable, Dict, List, Optional, TypeVar, cast
from uuid import UUID
import requests
from azure.core.exceptions import ClientAuthenticationError
from azure.identity import DefaultAzureCredential
from azure.keyvault.secrets import SecretClient
from azure.mgmt.resource import ResourceManagementClient
from azure.mgmt.subscription import SubscriptionClient
from memoization import cached
from msrestazure.azure_active_directory import MSIAuthentication
from msrestazure.tools import parse_resource_id
from onefuzztypes.primitives import Container, Region
from .monkeypatch import allow_more_workers, reduce_logging
# https://docs.microsoft.com/en-us/graph/api/overview?view=graph-rest-1.0
GRAPH_RESOURCE = "https://graph.microsoft.com"
GRAPH_RESOURCE_ENDPOINT = "https://graph.microsoft.com/v1.0"
@cached
def get_msi() -> MSIAuthentication:
allow_more_workers()
reduce_logging()
return MSIAuthentication()
@cached
def get_identity() -> DefaultAzureCredential:
allow_more_workers()
reduce_logging()
return DefaultAzureCredential()
@cached
def get_base_resource_group() -> Any: # should be str
return parse_resource_id(os.environ["ONEFUZZ_RESOURCE_GROUP"])["resource_group"]
@cached
def get_base_region() -> Region:
client = ResourceManagementClient(
credential=get_identity(), subscription_id=get_subscription()
)
group = client.resource_groups.get(get_base_resource_group())
return Region(group.location)
@cached
def get_subscription() -> Any: # should be str
return parse_resource_id(os.environ["ONEFUZZ_DATA_STORAGE"])["subscription"]
@cached
def get_insights_instrumentation_key() -> Any: # should be str
return os.environ["APPINSIGHTS_INSTRUMENTATIONKEY"]
@cached
def get_insights_appid() -> str:
return os.environ["APPINSIGHTS_APPID"]
@cached
def get_instance_name() -> str:
return os.environ["ONEFUZZ_INSTANCE_NAME"]
@cached
def get_instance_url() -> str:
return "https://%s.azurewebsites.net" % get_instance_name()
@cached
def get_agent_instance_url() -> str:
return get_instance_url()
@cached
def get_instance_id() -> UUID:
from .containers import get_blob
from .storage import StorageType
blob = get_blob(Container("base-config"), "instance_id", StorageType.config)
if blob is None:
raise Exception("missing instance_id")
return UUID(blob.decode())
DAY_IN_SECONDS = 60 * 60 * 24
@cached(ttl=DAY_IN_SECONDS)
def get_regions() -> List[Region]:
subscription = get_subscription()
client = SubscriptionClient(credential=get_identity())
locations = client.subscriptions.list_locations(subscription)
return sorted([Region(x.name) for x in locations])
class GraphQueryError(Exception):
def __init__(self, message: str, status_code: Optional[int]) -> None:
super(GraphQueryError, self).__init__(message)
self.message = message
self.status_code = status_code
def query_microsoft_graph(
method: str,
resource: str,
params: Optional[Dict] = None,
body: Optional[Dict] = None,
) -> Dict:
cred = get_identity()
access_token = cred.get_token(f"{GRAPH_RESOURCE}/.default")
url = urllib.parse.urljoin(f"{GRAPH_RESOURCE_ENDPOINT}/", resource)
headers = {
"Authorization": "Bearer %s" % access_token.token,
"Content-Type": "application/json",
}
response = requests.request(
method=method, url=url, headers=headers, params=params, json=body
)
if 200 <= response.status_code < 300:
if response.content and response.content.strip():
json = response.json()
if isinstance(json, Dict):
return json
else:
raise GraphQueryError(
"invalid data expected a json object: HTTP"
f" {response.status_code} - {json}",
response.status_code,
)
else:
return {}
else:
error_text = str(response.content, encoding="utf-8", errors="backslashreplace")
raise GraphQueryError(
f"request did not succeed: HTTP {response.status_code} - {error_text}",
response.status_code,
)
def query_microsoft_graph_list(
method: str,
resource: str,
params: Optional[Dict] = None,
body: Optional[Dict] = None,
) -> List[Any]:
result = query_microsoft_graph(
method,
resource,
params,
body,
)
value = result.get("value")
if isinstance(value, list):
return value
else:
raise GraphQueryError("Expected data containing a list of values", None)
@cached
def get_scaleset_identity_resource_path() -> str:
scaleset_id_name = "%s-scalesetid" % get_instance_name()
resource_group_path = "/subscriptions/%s/resourceGroups/%s/providers" % (
get_subscription(),
get_base_resource_group(),
)
return "%s/Microsoft.ManagedIdentity/userAssignedIdentities/%s" % (
resource_group_path,
scaleset_id_name,
)
@cached
def get_scaleset_principal_id() -> UUID:
api_version = "2018-11-30" # matches the apiversion in the deployment template
client = ResourceManagementClient(
credential=get_identity(), subscription_id=get_subscription()
)
uid = client.resources.get_by_id(get_scaleset_identity_resource_path(), api_version)
return UUID(uid.properties["principalId"])
@cached
def get_keyvault_client(vault_url: str) -> SecretClient:
return SecretClient(vault_url=vault_url, credential=DefaultAzureCredential())
def clear_azure_client_cache() -> None:
# clears the memoization of the Azure clients.
from .compute import get_compute_client
from .containers import get_blob_service
from .network_mgmt_client import get_network_client
from .storage import get_mgmt_client
    # currently memoization.cache does not preserve the wrapped function's types.
# As a workaround, CI comments out the `cached` wrapper, then runs the type
# validation. This enables calling the wrapper's clear_cache if it's not
# disabled.
for func in [
get_msi,
get_identity,
get_compute_client,
get_blob_service,
get_network_client,
get_mgmt_client,
]:
clear_func = getattr(func, "clear_cache", None)
if clear_func is not None:
clear_func()
T = TypeVar("T", bound=Callable[..., Any])
class retry_on_auth_failure:
def __call__(self, func: T) -> T:
@functools.wraps(func)
def decorated(*args, **kwargs): # type: ignore
try:
return func(*args, **kwargs)
except ClientAuthenticationError as err:
logging.warning(
"clearing authentication cache after auth failure: %s", err
)
clear_azure_client_cache()
return func(*args, **kwargs)
return cast(T, decorated)
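# Illustrative usage sketch, not part of the original module: wrap a helper
# with retry_on_auth_failure so a ClientAuthenticationError clears the cached
# Azure clients and the call is retried once. `example_list_regions` is a
# hypothetical name.
@retry_on_auth_failure()
def example_list_regions() -> List[Region]:
    return get_regions()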

View File

@ -1,29 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
from typing import Any
from azure.core.exceptions import ResourceNotFoundError
from msrestazure.azure_exceptions import CloudError
from .compute import get_compute_client
def list_disks(resource_group: str) -> Any:
logging.info("listing disks %s", resource_group)
compute_client = get_compute_client()
return compute_client.disks.list_by_resource_group(resource_group)
def delete_disk(resource_group: str, name: str) -> bool:
logging.info("deleting disks %s : %s", resource_group, name)
compute_client = get_compute_client()
try:
compute_client.disks.begin_delete(resource_group, name)
return True
except (ResourceNotFoundError, CloudError) as err:
logging.error("unable to delete disk: %s", err)
return False

View File

@ -1,51 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Dict, List, Protocol
from uuid import UUID
from ..config import InstanceConfig
from .creds import query_microsoft_graph_list
class GroupMembershipChecker(Protocol):
def is_member(self, group_ids: List[UUID], member_id: UUID) -> bool:
"""Check if member is part of at least one of the groups"""
if member_id in group_ids:
return True
groups = self.get_groups(member_id)
for g in group_ids:
if g in groups:
return True
return False
def get_groups(self, member_id: UUID) -> List[UUID]:
"""Gets all the groups of the provided member"""
def create_group_membership_checker() -> GroupMembershipChecker:
config = InstanceConfig.fetch()
if config.group_membership:
return StaticGroupMembership(config.group_membership)
else:
return AzureADGroupMembership()
class AzureADGroupMembership(GroupMembershipChecker):
def get_groups(self, member_id: UUID) -> List[UUID]:
response = query_microsoft_graph_list(
method="GET", resource=f"users/{member_id}/transitiveMemberOf"
)
return response
class StaticGroupMembership(GroupMembershipChecker):
def __init__(self, memberships: Dict[str, List[UUID]]):
self.memberships = memberships
def get_groups(self, member_id: UUID) -> List[UUID]:
return self.memberships.get(str(member_id), [])
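# Illustrative usage sketch, not part of the original module: resolve the
# configured membership checker and test whether a member belongs to at least
# one allowed group. The argument values are hypothetical.
def example_is_allowed(member_id: UUID, allowed_groups: List[UUID]) -> bool:
    checker = create_group_membership_checker()
    return checker.is_member(allowed_groups, member_id)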

View File

@ -1,55 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Union
from azure.core.exceptions import ResourceNotFoundError
from memoization import cached
from msrestazure.azure_exceptions import CloudError
from msrestazure.tools import parse_resource_id
from onefuzztypes.enums import OS, ErrorCode
from onefuzztypes.models import Error
from onefuzztypes.primitives import Region
from .compute import get_compute_client
@cached(ttl=60)
def get_os(region: Region, image: str) -> Union[Error, OS]:
client = get_compute_client()
# The dict returned here may not have any defined keys.
#
# See: https://github.com/Azure/msrestazure-for-python/blob/v0.6.3/msrestazure/tools.py#L134 # noqa: E501
parsed = parse_resource_id(image)
if "resource_group" in parsed:
if parsed["type"] == "galleries":
try:
# See: https://docs.microsoft.com/en-us/rest/api/compute/gallery-images/get#galleryimage # noqa: E501
name = client.gallery_images.get(
parsed["resource_group"], parsed["name"], parsed["child_name_1"]
).os_type.lower()
except (ResourceNotFoundError, CloudError) as err:
return Error(code=ErrorCode.INVALID_IMAGE, errors=[str(err)])
else:
try:
# See: https://docs.microsoft.com/en-us/rest/api/compute/images/get
name = client.images.get(
parsed["resource_group"], parsed["name"]
).storage_profile.os_disk.os_type.lower()
except (ResourceNotFoundError, CloudError) as err:
return Error(code=ErrorCode.INVALID_IMAGE, errors=[str(err)])
else:
publisher, offer, sku, version = image.split(":")
try:
if version == "latest":
version = client.virtual_machine_images.list(
region, publisher, offer, sku, top=1
)[0].name
name = client.virtual_machine_images.get(
region, publisher, offer, sku, version
).os_disk_image.operating_system.lower()
except (ResourceNotFoundError, CloudError) as err:
return Error(code=ErrorCode.INVALID_IMAGE, errors=[str(err)])
return OS[name]
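# Illustrative usage sketch, not part of the original module: resolve the OS
# of a marketplace image reference for a region. The image string is
# hypothetical.
def example_image_os(region: Region) -> Union[Error, OS]:
    return get_os(region, "Canonical:UbuntuServer:18.04-LTS:latest")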

View File

@ -1,168 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import os
from typing import Any, Dict, Optional, Union
from uuid import UUID
from azure.core.exceptions import ResourceNotFoundError
from azure.mgmt.network.models import Subnet
from msrestazure.azure_exceptions import CloudError
from msrestazure.tools import parse_resource_id
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error
from onefuzztypes.primitives import Region
from .creds import get_base_resource_group
from .network import Network
from .network_mgmt_client import get_network_client
from .nsg import NSG
from .vmss import get_instance_id
def get_scaleset_instance_ip(scaleset: UUID, machine_id: UUID) -> Optional[str]:
instance = get_instance_id(scaleset, machine_id)
if not isinstance(instance, str):
return None
resource_group = get_base_resource_group()
client = get_network_client()
intf = client.network_interfaces.list_virtual_machine_scale_set_network_interfaces(
resource_group, str(scaleset)
)
try:
for interface in intf:
resource = parse_resource_id(interface.virtual_machine.id)
if resource.get("resource_name") != instance:
continue
for config in interface.ip_configurations:
if config.private_ip_address is None:
continue
return str(config.private_ip_address)
except (ResourceNotFoundError, CloudError):
# this can fail if an interface is removed during the iteration
pass
return None
def get_ip(resource_group: str, name: str) -> Optional[Any]:
logging.info("getting ip %s:%s", resource_group, name)
network_client = get_network_client()
try:
return network_client.public_ip_addresses.get(resource_group, name)
except (ResourceNotFoundError, CloudError):
return None
def delete_ip(resource_group: str, name: str) -> Any:
logging.info("deleting ip %s:%s", resource_group, name)
network_client = get_network_client()
return network_client.public_ip_addresses.begin_delete(resource_group, name)
def create_ip(resource_group: str, name: str, region: Region) -> Any:
logging.info("creating ip for %s:%s in %s", resource_group, name, region)
network_client = get_network_client()
params: Dict[str, Union[str, Dict[str, str]]] = {
"location": region,
"public_ip_allocation_method": "Dynamic",
}
if "ONEFUZZ_OWNER" in os.environ:
params["tags"] = {"OWNER": os.environ["ONEFUZZ_OWNER"]}
return network_client.public_ip_addresses.begin_create_or_update(
resource_group, name, params
)
def get_public_nic(resource_group: str, name: str) -> Optional[Any]:
logging.info("getting nic: %s %s", resource_group, name)
network_client = get_network_client()
try:
return network_client.network_interfaces.get(resource_group, name)
except (ResourceNotFoundError, CloudError):
return None
def delete_nic(resource_group: str, name: str) -> Optional[Any]:
logging.info("deleting nic %s:%s", resource_group, name)
network_client = get_network_client()
return network_client.network_interfaces.begin_delete(resource_group, name)
def create_public_nic(
resource_group: str, name: str, region: Region, nsg: Optional[NSG]
) -> Optional[Error]:
logging.info("creating nic for %s:%s in %s", resource_group, name, region)
network = Network(region)
subnet_id = network.get_id()
if subnet_id is None:
network.create()
return None
if nsg:
subnet = network.get_subnet()
if isinstance(subnet, Subnet) and not subnet.network_security_group:
result = nsg.associate_subnet(network.get_vnet(), subnet)
if isinstance(result, Error):
return result
return None
ip = get_ip(resource_group, name)
if not ip:
create_ip(resource_group, name, region)
return None
params = {
"location": region,
"ip_configurations": [
{
"name": "myIPConfig",
"public_ip_address": ip,
"subnet": {"id": subnet_id},
}
],
}
if "ONEFUZZ_OWNER" in os.environ:
params["tags"] = {"OWNER": os.environ["ONEFUZZ_OWNER"]}
network_client = get_network_client()
try:
network_client.network_interfaces.begin_create_or_update(
resource_group, name, params
)
except (ResourceNotFoundError, CloudError) as err:
if "RetryableError" not in repr(err):
return Error(
code=ErrorCode.VM_CREATE_FAILED,
errors=["unable to create nic: %s" % err],
)
return None
def get_public_ip(resource_id: str) -> Optional[str]:
logging.info("getting ip for %s", resource_id)
network_client = get_network_client()
resource = parse_resource_id(resource_id)
ip = (
network_client.network_interfaces.get(
resource["resource_group"], resource["name"]
)
.ip_configurations[0]
.public_ip_address
)
resource = parse_resource_id(ip.id)
ip = network_client.public_ip_addresses.get(
resource["resource_group"], resource["name"]
).ip_address
if ip is None:
return None
else:
return str(ip)
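# Illustrative usage sketch, not part of the original module: create_public_nic
# above returns None while dependent resources (vnet, ip) are still being
# created, so a caller re-invokes it and then checks whether the NIC exists.
# The resource names are hypothetical.
def example_nic_ready(resource_group: str, name: str, region: Region) -> bool:
    result = create_public_nic(resource_group, name, region, nsg=None)
    if isinstance(result, Error):
        return False
    return get_public_nic(resource_group, name) is not None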

View File

@ -1,43 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
from typing import Dict
from azure.mgmt.loganalytics import LogAnalyticsManagementClient
from memoization import cached
from .creds import get_base_resource_group, get_identity, get_subscription
@cached
def get_monitor_client() -> LogAnalyticsManagementClient:
return LogAnalyticsManagementClient(get_identity(), get_subscription())
@cached(ttl=60)
def get_monitor_settings() -> Dict[str, str]:
resource_group = get_base_resource_group()
workspace_name = os.environ["ONEFUZZ_MONITOR"]
client = get_monitor_client()
customer_id = client.workspaces.get(resource_group, workspace_name).customer_id
shared_key = client.shared_keys.get_shared_keys(
resource_group, workspace_name
).primary_shared_key
return {"id": customer_id, "key": shared_key}
def get_workspace_id() -> str:
# TODO:
# Once #1679 merges, we can use ONEFUZZ_MONITOR instead of ONEFUZZ_INSTANCE_NAME
workspace_id = (
"/subscriptions/%s/resourceGroups/%s/providers/microsoft.operationalinsights/workspaces/%s" # noqa: E501
% (
get_subscription(),
get_base_resource_group(),
os.environ["ONEFUZZ_INSTANCE_NAME"],
)
)
return workspace_id

View File

@ -1,9 +0,0 @@
from azure.mgmt.monitor import MonitorManagementClient
from memoization import cached
from .creds import get_identity, get_subscription
@cached
def get_monitor_client() -> MonitorManagementClient:
return MonitorManagementClient(get_identity(), get_subscription())

View File

@ -1,56 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import inspect
import logging
WORKERS_DONE = False
REDUCE_LOGGING = False
def allow_more_workers() -> None:
global WORKERS_DONE
if WORKERS_DONE:
return
stack = inspect.stack()
for entry in stack:
if entry.filename.endswith("azure_functions_worker/dispatcher.py"):
if entry.frame.f_locals["self"]._sync_call_tp._max_workers == 1:
logging.info("bumped thread worker count to 50")
entry.frame.f_locals["self"]._sync_call_tp._max_workers = 50
WORKERS_DONE = True
# TODO: Replace this with a better method for filtering out logging
# See https://github.com/Azure/azure-functions-python-worker/issues/743
def reduce_logging() -> None:
global REDUCE_LOGGING
if REDUCE_LOGGING:
return
to_quiet = [
"azure",
"cli",
"grpc",
"concurrent",
"oauthlib",
"msrest",
"opencensus",
"urllib3",
"requests",
"aiohttp",
"asyncio",
"adal-python",
]
for name in logging.Logger.manager.loggerDict:
logger = logging.getLogger(name)
for prefix in to_quiet:
if logger.name.startswith(prefix):
logger.level = logging.WARN
REDUCE_LOGGING = True

View File

@ -1,78 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import uuid
from typing import Optional, Union
from azure.mgmt.network.models import Subnet, VirtualNetwork
from msrestazure.azure_exceptions import CloudError
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error, NetworkConfig
from onefuzztypes.primitives import Region
from ..config import InstanceConfig
from .creds import get_base_resource_group
from .subnet import create_virtual_network, get_subnet, get_subnet_id, get_vnet
# This was generated randomly and should be preserved moving forwards
NETWORK_GUID_NAMESPACE = uuid.UUID("372977ad-b533-416a-b1b4-f770898e0b11")
class Network:
def __init__(self, region: Region):
self.group = get_base_resource_group()
self.region = region
self.network_config = InstanceConfig.fetch().network_config
# Network names will be calculated from the address_space/subnet
# *except* if they are the original values. This allows backwards
        # compatibility with existing configs if you don't change the network
# configs.
if (
self.network_config.address_space
== NetworkConfig.__fields__["address_space"].default
and self.network_config.subnet == NetworkConfig.__fields__["subnet"].default
):
self.name: str = self.region
else:
network_id = uuid.uuid5(
NETWORK_GUID_NAMESPACE,
"|".join(
[self.network_config.address_space, self.network_config.subnet]
),
)
self.name = f"{self.region}-{network_id}"
def exists(self) -> bool:
return self.get_id() is not None
def get_id(self) -> Optional[str]:
return get_subnet_id(self.group, self.name, self.name)
def get_subnet(self) -> Optional[Subnet]:
return get_subnet(self.group, self.name, self.name)
def get_vnet(self) -> Optional[VirtualNetwork]:
return get_vnet(self.group, self.name)
def create(self) -> Union[None, Error]:
if not self.exists():
result = create_virtual_network(
self.group, self.name, self.region, self.network_config
)
if isinstance(result, CloudError):
error = Error(
code=ErrorCode.UNABLE_TO_CREATE_NETWORK, errors=[result.message]
)
logging.error(
"network creation failed: %s:%s- %s",
self.name,
self.region,
error,
)
return error
return None
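# Illustrative usage sketch, not part of the original module: ensure a
# region's virtual network exists and return its subnet id once available.
# The helper name is hypothetical.
def example_ensure_network(region: Region) -> Optional[str]:
    network = Network(region)
    if not network.exists():
        error = network.create()
        if isinstance(error, Error):
            return None
    return network.get_id()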

View File

@ -1,9 +0,0 @@
from azure.mgmt.network import NetworkManagementClient
from memoization import cached
from .creds import get_identity, get_subscription
@cached
def get_network_client() -> NetworkManagementClient:
return NetworkManagementClient(get_identity(), get_subscription())

View File

@ -1,489 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import os
from typing import Dict, List, Optional, Set, Union, cast
from azure.core.exceptions import HttpResponseError, ResourceNotFoundError
from azure.mgmt.network.models import (
NetworkInterface,
NetworkSecurityGroup,
SecurityRule,
SecurityRuleAccess,
Subnet,
VirtualNetwork,
)
from msrestazure.azure_exceptions import CloudError
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error, NetworkSecurityGroupConfig
from onefuzztypes.primitives import Region
from pydantic import BaseModel, validator
from .creds import get_base_resource_group
from .network_mgmt_client import get_network_client
def is_concurrent_request_error(err: str) -> bool:
return "The request failed due to conflict with a concurrent request" in str(err)
def get_nsg(name: str) -> Optional[NetworkSecurityGroup]:
resource_group = get_base_resource_group()
logging.debug("getting nsg: %s", name)
network_client = get_network_client()
try:
nsg = network_client.network_security_groups.get(resource_group, name)
return cast(NetworkSecurityGroup, nsg)
except (ResourceNotFoundError, CloudError) as err:
logging.debug("nsg %s does not exist: %s", name, err)
return None
def create_nsg(name: str, location: Region) -> Union[None, Error]:
resource_group = get_base_resource_group()
logging.info("creating nsg %s:%s:%s", resource_group, location, name)
network_client = get_network_client()
params: Dict = {
"location": location,
}
if "ONEFUZZ_OWNER" in os.environ:
params["tags"] = {"OWNER": os.environ["ONEFUZZ_OWNER"]}
try:
network_client.network_security_groups.begin_create_or_update(
resource_group, name, params
)
except (ResourceNotFoundError, CloudError) as err:
if is_concurrent_request_error(str(err)):
logging.debug(
"create NSG had conflicts with concurrent request, ignoring %s", err
)
return None
return Error(
code=ErrorCode.UNABLE_TO_CREATE,
errors=["Unable to create nsg %s due to %s" % (name, err)],
)
return None
def list_nsgs() -> List[NetworkSecurityGroup]:
resource_group = get_base_resource_group()
network_client = get_network_client()
return list(network_client.network_security_groups.list(resource_group))
def update_nsg(nsg: NetworkSecurityGroup) -> Union[None, Error]:
resource_group = get_base_resource_group()
logging.info("updating nsg %s:%s:%s", resource_group, nsg.location, nsg.name)
network_client = get_network_client()
try:
network_client.network_security_groups.begin_create_or_update(
resource_group, nsg.name, nsg
)
except (ResourceNotFoundError, CloudError) as err:
if is_concurrent_request_error(str(err)):
logging.debug(
"create NSG had conflicts with concurrent request, ignoring %s", err
)
return None
return Error(
code=ErrorCode.UNABLE_TO_CREATE,
errors=["Unable to update nsg %s due to %s" % (nsg.name, err)],
)
return None
# Return True if NSG is created using OneFuzz naming convention.
# Therefore NSG belongs to OneFuzz.
def ok_to_delete(active_regions: Set[Region], nsg_region: str, nsg_name: str) -> bool:
return nsg_region not in active_regions and nsg_region == nsg_name
def is_onefuzz_nsg(nsg_region: str, nsg_name: str) -> bool:
return nsg_region == nsg_name
# Returns True if deletion completed (thus resource not found) or successfully started.
# Returns False if failed to start deletion.
def start_delete_nsg(name: str) -> bool:
# NSG can be only deleted if no other resource is associated with it
resource_group = get_base_resource_group()
logging.info("deleting nsg: %s %s", resource_group, name)
network_client = get_network_client()
try:
network_client.network_security_groups.begin_delete(resource_group, name)
return True
except HttpResponseError as err:
err_str = str(err)
if (
"cannot be deleted because it is in use by the following resources"
) in err_str:
return False
except ResourceNotFoundError:
return True
return False
def set_allowed(name: str, sources: NetworkSecurityGroupConfig) -> Union[None, Error]:
resource_group = get_base_resource_group()
nsg = get_nsg(name)
if not nsg:
return Error(
code=ErrorCode.UNABLE_TO_FIND,
errors=["cannot update nsg rules. nsg %s not found" % name],
)
logging.info(
"setting allowed incoming connection sources for nsg: %s %s",
resource_group,
name,
)
all_sources = sources.allowed_ips + sources.allowed_service_tags
security_rules = []
# NSG security rule priority range defined here:
# https://docs.microsoft.com/en-us/azure/virtual-network/network-security-groups-overview
min_priority = 100
# NSG rules per NSG limits:
# https://docs.microsoft.com/en-us/azure/azure-resource-manager/management/azure-subscription-service-limits?toc=/azure/virtual-network/toc.json#networking-limits
max_rule_count = 1000
if len(all_sources) > max_rule_count:
return Error(
code=ErrorCode.INVALID_REQUEST,
errors=[
"too many rules provided %d. Max allowed: %d"
% ((len(all_sources)), max_rule_count),
],
)
priority = min_priority
for src in all_sources:
security_rules.append(
SecurityRule(
name="Allow" + str(priority),
protocol="*",
source_port_range="*",
destination_port_range="*",
source_address_prefix=src,
destination_address_prefix="*",
access=SecurityRuleAccess.ALLOW,
priority=priority, # between 100 and 4096
direction="Inbound",
)
)
# Will not exceed `max_rule_count` or max NSG priority (4096)
# due to earlier check of `len(all_sources)`.
priority += 1
nsg.security_rules = security_rules
return update_nsg(nsg)
def clear_all_rules(name: str) -> Union[None, Error]:
return set_allowed(name, NetworkSecurityGroupConfig())
def get_all_rules(name: str) -> Union[Error, List[SecurityRule]]:
nsg = get_nsg(name)
if not nsg:
return Error(
code=ErrorCode.UNABLE_TO_FIND,
errors=["cannot get nsg rules. nsg %s not found" % name],
)
return cast(List[SecurityRule], nsg.security_rules)
def associate_nic(name: str, nic: NetworkInterface) -> Union[None, Error]:
resource_group = get_base_resource_group()
nsg = get_nsg(name)
if not nsg:
return Error(
code=ErrorCode.UNABLE_TO_FIND,
errors=["cannot associate nic. nsg %s not found" % name],
)
if nsg.location != nic.location:
return Error(
code=ErrorCode.UNABLE_TO_UPDATE,
errors=[
"network interface and nsg have to be in the same region.",
"nsg %s %s, nic: %s %s"
% (nsg.name, nsg.location, nic.name, nic.location),
],
)
if nic.network_security_group and nic.network_security_group.id == nsg.id:
logging.info(
"NIC %s and NSG %s already associated, not updating", nic.name, name
)
return None
logging.info("associating nic %s with nsg: %s %s", nic.name, resource_group, name)
nic.network_security_group = nsg
network_client = get_network_client()
try:
network_client.network_interfaces.begin_create_or_update(
resource_group, nic.name, nic
)
except (ResourceNotFoundError, CloudError) as err:
if is_concurrent_request_error(str(err)):
logging.debug(
"associate NSG with NIC had conflicts",
"with concurrent request, ignoring %s",
err,
)
return None
return Error(
code=ErrorCode.UNABLE_TO_UPDATE,
errors=[
"Unable to associate nsg %s with nic %s due to %s"
% (
name,
nic.name,
err,
)
],
)
return None
def dissociate_nic(name: str, nic: NetworkInterface) -> Union[None, Error]:
if nic.network_security_group is None:
return None
resource_group = get_base_resource_group()
nsg = get_nsg(name)
if not nsg:
return Error(
code=ErrorCode.UNABLE_TO_FIND,
errors=["cannot update nsg rules. nsg %s not found" % name],
)
if nsg.id != nic.network_security_group.id:
return Error(
code=ErrorCode.UNABLE_TO_UPDATE,
errors=[
"network interface is not associated with this nsg.",
"nsg %s, nic: %s, nic.nsg: %s"
% (
nsg.id,
nic.name,
nic.network_security_group.id,
),
],
)
logging.info("dissociating nic %s with nsg: %s %s", nic.name, resource_group, name)
nic.network_security_group = None
network_client = get_network_client()
try:
network_client.network_interfaces.begin_create_or_update(
resource_group, nic.name, nic
)
except (ResourceNotFoundError, CloudError) as err:
if is_concurrent_request_error(str(err)):
logging.debug(
"dissociate nsg with nic had conflicts with ",
"concurrent request, ignoring %s",
err,
)
return None
return Error(
code=ErrorCode.UNABLE_TO_UPDATE,
errors=[
"Unable to dissociate nsg %s with nic %s due to %s"
% (
name,
nic.name,
err,
)
],
)
return None
def associate_subnet(
name: str, vnet: VirtualNetwork, subnet: Subnet
) -> Union[None, Error]:
resource_group = get_base_resource_group()
nsg = get_nsg(name)
if not nsg:
return Error(
code=ErrorCode.UNABLE_TO_FIND,
errors=["cannot associate subnet. nsg %s not found" % name],
)
if nsg.location != vnet.location:
return Error(
code=ErrorCode.UNABLE_TO_UPDATE,
errors=[
"subnet and nsg have to be in the same region.",
"nsg %s %s, subnet: %s %s"
% (nsg.name, nsg.location, subnet.name, subnet.location),
],
)
if subnet.network_security_group and subnet.network_security_group.id == nsg.id:
logging.info(
"Subnet %s and NSG %s already associated, not updating", subnet.name, name
)
return None
logging.info(
"associating subnet %s with nsg: %s %s", subnet.name, resource_group, name
)
subnet.network_security_group = nsg
network_client = get_network_client()
try:
network_client.subnets.begin_create_or_update(
resource_group, vnet.name, subnet.name, subnet
)
except (ResourceNotFoundError, CloudError) as err:
if is_concurrent_request_error(str(err)):
logging.debug(
"associate NSG with subnet had conflicts",
"with concurrent request, ignoring %s",
err,
)
return None
return Error(
code=ErrorCode.UNABLE_TO_UPDATE,
errors=[
"Unable to associate nsg %s with subnet %s due to %s"
% (
name,
subnet.name,
err,
)
],
)
return None
def dissociate_subnet(
name: str, vnet: VirtualNetwork, subnet: Subnet
) -> Union[None, Error]:
if subnet.network_security_group is None:
return None
resource_group = get_base_resource_group()
nsg = get_nsg(name)
if not nsg:
return Error(
code=ErrorCode.UNABLE_TO_FIND,
errors=["cannot update nsg rules. nsg %s not found" % name],
)
if nsg.id != subnet.network_security_group.id:
return Error(
code=ErrorCode.UNABLE_TO_UPDATE,
errors=[
"subnet is not associated with this nsg.",
"nsg %s, subnet: %s, subnet.nsg: %s"
% (
nsg.id,
subnet.name,
subnet.network_security_group.id,
),
],
)
logging.info(
"dissociating subnet %s with nsg: %s %s", subnet.name, resource_group, name
)
subnet.network_security_group = None
network_client = get_network_client()
try:
network_client.subnets.begin_create_or_update(
resource_group, vnet.name, subnet.name, subnet
)
except (ResourceNotFoundError, CloudError) as err:
if is_concurrent_request_error(str(err)):
logging.debug(
"dissociate nsg with subnet had conflicts with ",
"concurrent request, ignoring %s",
err,
)
return None
return Error(
code=ErrorCode.UNABLE_TO_UPDATE,
errors=[
"Unable to dissociate nsg %s with subnet %s due to %s"
% (
name,
subnet.name,
err,
)
],
)
return None
class NSG(BaseModel):
name: str
region: Region
@validator("name", allow_reuse=True)
def check_name(cls, value: str) -> str:
# https://docs.microsoft.com/en-us/azure/azure-resource-manager/management/resource-name-rules
if len(value) > 80:
raise ValueError("NSG name too long")
return value
def create(self) -> Union[None, Error]:
# Optimization: if NSG exists - do not try
# to create it
if self.get() is not None:
return None
return create_nsg(self.name, self.region)
def start_delete(self) -> bool:
return start_delete_nsg(self.name)
def get(self) -> Optional[NetworkSecurityGroup]:
return get_nsg(self.name)
def set_allowed_sources(
self, sources: NetworkSecurityGroupConfig
) -> Union[None, Error]:
return set_allowed(self.name, sources)
def clear_all_rules(self) -> Union[None, Error]:
return clear_all_rules(self.name)
def get_all_rules(self) -> Union[Error, List[SecurityRule]]:
return get_all_rules(self.name)
def associate_nic(self, nic: NetworkInterface) -> Union[None, Error]:
return associate_nic(self.name, nic)
def dissociate_nic(self, nic: NetworkInterface) -> Union[None, Error]:
return dissociate_nic(self.name, nic)
def associate_subnet(
self, vnet: VirtualNetwork, subnet: Subnet
) -> Union[None, Error]:
return associate_subnet(self.name, vnet, subnet)
def dissociate_subnet(
self, vnet: VirtualNetwork, subnet: Subnet
) -> Union[None, Error]:
return dissociate_subnet(self.name, vnet, subnet)
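# Illustrative usage sketch, not part of the original module: create an NSG
# for a region and restrict inbound traffic to a single allowed source. The
# name, region, and address values are hypothetical.
def example_lockdown_nsg() -> Union[None, Error]:
    nsg = NSG(name="eastus", region=Region("eastus"))
    result = nsg.create()
    if isinstance(result, Error):
        return result
    sources = NetworkSecurityGroupConfig(allowed_ips=["10.0.0.0/8"])
    return nsg.set_allowed_sources(sources)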

View File

@ -1,203 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import base64
import datetime
import json
import logging
from typing import List, Optional, Type, TypeVar, Union
from uuid import UUID
from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
from azure.storage.queue import (
QueueSasPermissions,
QueueServiceClient,
generate_queue_sas,
)
from memoization import cached
from pydantic import BaseModel
from .storage import StorageType, get_primary_account, get_storage_account_name_key
QueueNameType = Union[str, UUID]
DEFAULT_TTL = -1
DEFAULT_DURATION = datetime.timedelta(days=30)
@cached(ttl=60)
def get_queue_client(storage_type: StorageType) -> QueueServiceClient:
account_id = get_primary_account(storage_type)
logging.debug("getting blob container (account_id: %s)", account_id)
name, key = get_storage_account_name_key(account_id)
account_url = "https://%s.queue.core.windows.net" % name
client = QueueServiceClient(
account_url=account_url,
credential={"account_name": name, "account_key": key},
)
return client
@cached(ttl=60)
def get_queue_sas(
queue: QueueNameType,
storage_type: StorageType,
*,
read: bool = False,
add: bool = False,
update: bool = False,
process: bool = False,
duration: Optional[datetime.timedelta] = None,
) -> str:
if duration is None:
duration = DEFAULT_DURATION
account_id = get_primary_account(storage_type)
logging.debug("getting queue sas %s (account_id: %s)", queue, account_id)
name, key = get_storage_account_name_key(account_id)
expiry = datetime.datetime.utcnow() + duration
token = generate_queue_sas(
name,
str(queue),
key,
permission=QueueSasPermissions(
read=read, add=add, update=update, process=process
),
expiry=expiry,
)
url = "https://%s.queue.core.windows.net/%s?%s" % (name, queue, token)
return url
@cached(ttl=60)
def create_queue(name: QueueNameType, storage_type: StorageType) -> None:
client = get_queue_client(storage_type)
try:
client.create_queue(str(name))
except ResourceExistsError:
pass
def delete_queue(name: QueueNameType, storage_type: StorageType) -> None:
client = get_queue_client(storage_type)
queues = client.list_queues()
if str(name) in [x["name"] for x in queues]:
client.delete_queue(name)
def get_queue(
name: QueueNameType, storage_type: StorageType
) -> Optional[QueueServiceClient]:
client = get_queue_client(storage_type)
try:
return client.get_queue_client(str(name))
except ResourceNotFoundError:
return None
def clear_queue(name: QueueNameType, storage_type: StorageType) -> None:
queue = get_queue(name, storage_type)
if queue:
try:
queue.clear_messages()
except ResourceNotFoundError:
pass
def send_message(
name: QueueNameType,
message: bytes,
storage_type: StorageType,
*,
visibility_timeout: Optional[int] = None,
time_to_live: int = DEFAULT_TTL,
) -> None:
queue = get_queue(name, storage_type)
if queue:
try:
queue.send_message(
base64.b64encode(message).decode(),
visibility_timeout=visibility_timeout,
time_to_live=time_to_live,
)
except ResourceNotFoundError:
pass
def remove_first_message(name: QueueNameType, storage_type: StorageType) -> bool:
queue = get_queue(name, storage_type)
if queue:
try:
for message in queue.receive_messages():
queue.delete_message(message)
return True
except ResourceNotFoundError:
return False
return False
A = TypeVar("A", bound=BaseModel)
MIN_PEEK_SIZE = 1
MAX_PEEK_SIZE = 32
# Peek at a max of 32 messages
# https://docs.microsoft.com/en-us/python/api/azure-storage-queue/azure.storage.queue.queueclient
def peek_queue(
name: QueueNameType,
storage_type: StorageType,
*,
object_type: Type[A],
max_messages: int = MAX_PEEK_SIZE,
) -> List[A]:
result: List[A] = []
# message count
if max_messages < MIN_PEEK_SIZE or max_messages > MAX_PEEK_SIZE:
raise ValueError("invalid max messages: %s" % max_messages)
try:
queue = get_queue(name, storage_type)
if not queue:
return result
for message in queue.peek_messages(max_messages=max_messages):
decoded = base64.b64decode(message.content)
raw = json.loads(decoded)
result.append(object_type.parse_obj(raw))
except ResourceNotFoundError:
return result
return result
def queue_object(
name: QueueNameType,
message: BaseModel,
storage_type: StorageType,
*,
visibility_timeout: Optional[int] = None,
time_to_live: int = DEFAULT_TTL,
) -> bool:
queue = get_queue(name, storage_type)
if not queue:
raise Exception("unable to queue object, no such queue: %s" % queue)
encoded = base64.b64encode(message.json(exclude_none=True).encode()).decode()
try:
queue.send_message(
encoded, visibility_timeout=visibility_timeout, time_to_live=time_to_live
)
return True
except ResourceNotFoundError:
return False
def get_resource_id(queue_name: QueueNameType, storage_type: StorageType) -> str:
account_id = get_primary_account(storage_type)
resource_uri = "%s/services/queue/queues/%s" % (account_id, queue_name)
return resource_uri
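# Illustrative usage sketch, not part of the original module: enqueue a
# pydantic message and peek it back. `ExampleMessage` and the queue name are
# hypothetical.
class ExampleMessage(BaseModel):
    text: str


def example_queue_round_trip() -> List[ExampleMessage]:
    create_queue("example-queue", StorageType.config)
    queue_object("example-queue", ExampleMessage(text="hello"), StorageType.config)
    return peek_queue("example-queue", StorageType.config, object_type=ExampleMessage)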

View File

@ -1,120 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import os
import random
from enum import Enum
from typing import List, Tuple, cast
from azure.mgmt.storage import StorageManagementClient
from memoization import cached
from msrestazure.tools import parse_resource_id
from .creds import get_base_resource_group, get_identity, get_subscription
class StorageType(Enum):
corpus = "corpus"
config = "config"
@cached
def get_mgmt_client() -> StorageManagementClient:
return StorageManagementClient(
credential=get_identity(), subscription_id=get_subscription()
)
@cached
def get_fuzz_storage() -> str:
return os.environ["ONEFUZZ_DATA_STORAGE"]
@cached
def get_func_storage() -> str:
return os.environ["ONEFUZZ_FUNC_STORAGE"]
@cached
def get_primary_account(storage_type: StorageType) -> str:
if storage_type == StorageType.corpus:
# see #322 for discussion about typing
return get_fuzz_storage()
elif storage_type == StorageType.config:
# see #322 for discussion about typing
return get_func_storage()
raise NotImplementedError
@cached
def get_accounts(storage_type: StorageType) -> List[str]:
if storage_type == StorageType.corpus:
return corpus_accounts()
elif storage_type == StorageType.config:
return [get_func_storage()]
else:
raise NotImplementedError
@cached
def get_storage_account_name_key(account_id: str) -> Tuple[str, str]:
resource = parse_resource_id(account_id)
key = get_storage_account_name_key_by_name(resource["name"])
return resource["name"], key
@cached
def get_storage_account_name_key_by_name(account_name: str) -> str:
client = get_mgmt_client()
group = get_base_resource_group()
key = client.storage_accounts.list_keys(group, account_name).keys[0].value
return cast(str, key)
def choose_account(storage_type: StorageType) -> str:
accounts = get_accounts(storage_type)
if not accounts:
raise Exception(f"no storage accounts for {storage_type}")
if len(accounts) == 1:
return accounts[0]
# Use a random secondary storage account if any are available. This
# reduces IOP contention for the Storage Queues, which are only available
# on primary accounts
#
# security note: this is not used as a security feature
return random.choice(accounts[1:]) # nosec
@cached
def corpus_accounts() -> List[str]:
skip = get_func_storage()
results = [get_fuzz_storage()]
client = get_mgmt_client()
group = get_base_resource_group()
for account in client.storage_accounts.list_by_resource_group(group):
# protection from someone adding the corpus tag to the config account
if account.id == skip:
continue
if account.id in results:
continue
if account.primary_endpoints.blob is None:
continue
if (
"storage_type" not in account.tags
or account.tags["storage_type"] != "corpus"
):
continue
results.append(account.id)
logging.info("corpus accounts: %s", results)
return results

View File

@ -1,97 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import os
from typing import Any, Optional, Union, cast
from azure.core.exceptions import ResourceNotFoundError
from azure.mgmt.network.models import Subnet, VirtualNetwork
from msrestazure.azure_exceptions import CloudError
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error, NetworkConfig
from onefuzztypes.primitives import Region
from .network_mgmt_client import get_network_client
def get_vnet(resource_group: str, name: str) -> Optional[VirtualNetwork]:
network_client = get_network_client()
try:
vnet = network_client.virtual_networks.get(resource_group, name)
return cast(VirtualNetwork, vnet)
except (CloudError, ResourceNotFoundError):
logging.info(
"vnet missing: resource group:%s name:%s",
resource_group,
name,
)
return None
def get_subnet(
resource_group: str, vnet_name: str, subnet_name: str
) -> Optional[Subnet]:
    # Fetch via the vnet so the NSG field is properly set on the subnet
vnet = get_vnet(resource_group, vnet_name)
if vnet:
for subnet in vnet.subnets:
if subnet.name == subnet_name:
return subnet
return None
def get_subnet_id(resource_group: str, name: str, subnet_name: str) -> Optional[str]:
subnet = get_subnet(resource_group, name, subnet_name)
if subnet and isinstance(subnet.id, str):
return subnet.id
else:
return None
def delete_subnet(resource_group: str, name: str) -> Union[None, CloudError, Any]:
network_client = get_network_client()
try:
return network_client.virtual_networks.begin_delete(resource_group, name)
except (CloudError, ResourceNotFoundError) as err:
if err.error and "InUseSubnetCannotBeDeleted" in str(err.error):
logging.error(
"subnet delete failed: %s %s : %s", resource_group, name, repr(err)
)
return None
raise err
def create_virtual_network(
resource_group: str,
name: str,
region: Region,
network_config: NetworkConfig,
) -> Optional[Error]:
logging.info(
"creating subnet - resource group:%s name:%s region:%s",
resource_group,
name,
region,
)
network_client = get_network_client()
params = {
"location": region,
"address_space": {"address_prefixes": [network_config.address_space]},
"subnets": [{"name": name, "address_prefix": network_config.subnet}],
}
if "ONEFUZZ_OWNER" in os.environ:
params["tags"] = {"OWNER": os.environ["ONEFUZZ_OWNER"]}
try:
network_client.virtual_networks.begin_create_or_update(
resource_group, name, params
)
except (CloudError, ResourceNotFoundError) as err:
return Error(code=ErrorCode.UNABLE_TO_CREATE_NETWORK, errors=[str(err)])
return None

View File

@ -1,30 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import os
from typing import Optional
from azure.cosmosdb.table import TableService
from memoization import cached
from .storage import get_storage_account_name_key
@cached(ttl=60)
def get_client(
table: Optional[str] = None, account_id: Optional[str] = None
) -> TableService:
if account_id is None:
account_id = os.environ["ONEFUZZ_FUNC_STORAGE"]
logging.debug("getting table account: (account_id: %s)", account_id)
name, key = get_storage_account_name_key(account_id)
client = TableService(account_name=name, account_key=key)
if table and not client.exists(table):
logging.info("creating missing table %s", table)
client.create_table(table, fail_on_exist=False)
return client

View File

@ -1,313 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import os
from typing import Any, Dict, List, Optional, Union, cast
from uuid import UUID
from azure.core.exceptions import ResourceNotFoundError
from azure.mgmt.compute.models import VirtualMachine
from msrestazure.azure_exceptions import CloudError
from onefuzztypes.enums import OS, ErrorCode
from onefuzztypes.models import Authentication, Error
from onefuzztypes.primitives import Extension, Region
from pydantic import BaseModel, validator
from .compute import get_compute_client
from .creds import get_base_resource_group
from .disk import delete_disk, list_disks
from .image import get_os
from .ip import create_public_nic, delete_ip, delete_nic, get_ip, get_public_nic
from .nsg import NSG
def get_vm(name: str) -> Optional[VirtualMachine]:
resource_group = get_base_resource_group()
logging.debug("getting vm: %s", name)
compute_client = get_compute_client()
try:
return cast(
VirtualMachine,
compute_client.virtual_machines.get(
resource_group, name, expand="instanceView"
),
)
except (ResourceNotFoundError, CloudError) as err:
logging.debug("vm does not exist %s", err)
return None
def create_vm(
name: str,
location: Region,
vm_sku: str,
image: str,
password: str,
ssh_public_key: str,
nsg: Optional[NSG],
tags: Optional[Dict[str, str]],
) -> Union[None, Error]:
resource_group = get_base_resource_group()
logging.info("creating vm %s:%s:%s", resource_group, location, name)
compute_client = get_compute_client()
nic = get_public_nic(resource_group, name)
if nic is None:
result = create_public_nic(resource_group, name, location, nsg)
if isinstance(result, Error):
return result
logging.info("waiting on nic creation")
return None
# when public nic is created, VNET must exist at that point
# this is logic of get_public_nic function
if nsg:
result = nsg.associate_nic(nic)
if isinstance(result, Error):
return result
if image.startswith("/"):
image_ref = {"id": image}
else:
image_val = image.split(":", 4)
image_ref = {
"publisher": image_val[0],
"offer": image_val[1],
"sku": image_val[2],
"version": image_val[3],
}
params: Dict = {
"location": location,
"os_profile": {
"computer_name": "node",
"admin_username": "onefuzz",
},
"hardware_profile": {"vm_size": vm_sku},
"storage_profile": {"image_reference": image_ref},
"network_profile": {"network_interfaces": [{"id": nic.id}]},
}
image_os = get_os(location, image)
if isinstance(image_os, Error):
return image_os
if image_os == OS.windows:
params["os_profile"]["admin_password"] = password
if image_os == OS.linux:
params["os_profile"]["linux_configuration"] = {
"disable_password_authentication": True,
"ssh": {
"public_keys": [
{
"path": "/home/onefuzz/.ssh/authorized_keys",
"key_data": ssh_public_key,
}
]
},
}
if "ONEFUZZ_OWNER" in os.environ:
params["tags"] = {"OWNER": os.environ["ONEFUZZ_OWNER"]}
if tags:
params["tags"].update(tags.copy())
try:
compute_client.virtual_machines.begin_create_or_update(
resource_group, name, params
)
except (ResourceNotFoundError, CloudError) as err:
if "The request failed due to conflict with a concurrent request" in str(err):
logging.debug(
"create VM had conflicts with concurrent request, ignoring %s", err
)
return None
return Error(code=ErrorCode.VM_CREATE_FAILED, errors=[str(err)])
return None
def get_extension(vm_name: str, extension_name: str) -> Optional[Any]:
resource_group = get_base_resource_group()
logging.debug(
"getting extension: %s:%s:%s",
resource_group,
vm_name,
extension_name,
)
compute_client = get_compute_client()
try:
return compute_client.virtual_machine_extensions.get(
resource_group, vm_name, extension_name
)
except (ResourceNotFoundError, CloudError) as err:
logging.info("extension does not exist %s", err)
return None
def create_extension(vm_name: str, extension: Dict) -> Any:
resource_group = get_base_resource_group()
logging.info(
"creating extension: %s:%s:%s", resource_group, vm_name, extension["name"]
)
compute_client = get_compute_client()
return compute_client.virtual_machine_extensions.begin_create_or_update(
resource_group, vm_name, extension["name"], extension
)
def delete_vm(name: str) -> Any:
resource_group = get_base_resource_group()
logging.info("deleting vm: %s %s", resource_group, name)
compute_client = get_compute_client()
return compute_client.virtual_machines.begin_delete(resource_group, name)
def has_components(name: str) -> bool:
# check if any of the components associated with a VM still exist.
#
    # Azure VM deletion requires that we first delete the VM, then delete all
    # of its resources. This check ensures everything is cleaned up before
    # marking the VM "done"
resource_group = get_base_resource_group()
if get_vm(name):
return True
if get_public_nic(resource_group, name):
return True
if get_ip(resource_group, name):
return True
disks = [x.name for x in list_disks(resource_group) if x.name.startswith(name)]
if disks:
return True
return False
def delete_vm_components(name: str, nsg: Optional[NSG]) -> bool:
resource_group = get_base_resource_group()
logging.info("deleting vm components %s:%s", resource_group, name)
if get_vm(name):
logging.info("deleting vm %s:%s", resource_group, name)
delete_vm(name)
return False
nic = get_public_nic(resource_group, name)
if nic:
logging.info("deleting nic %s:%s", resource_group, name)
if nic.network_security_group and nsg:
nsg.dissociate_nic(nic)
return False
delete_nic(resource_group, name)
return False
if get_ip(resource_group, name):
logging.info("deleting ip %s:%s", resource_group, name)
delete_ip(resource_group, name)
return False
disks = [x.name for x in list_disks(resource_group) if x.name.startswith(name)]
if disks:
for disk in disks:
logging.info("deleting disk %s:%s", resource_group, disk)
delete_disk(resource_group, disk)
return False
return True
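
# Illustrative sketch, not part of the original module: has_components and
# delete_vm_components form a polling-style teardown - each call removes at
# most one resource and returns False until the VM, NIC, IP, and disks are all
# gone. A hypothetical caller simply re-invokes it until it reports True:
def _example_drain_vm(name: str, nsg: Optional[NSG]) -> bool:
    # True only once every associated resource has been deleted
    return delete_vm_components(name, nsg)
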
class VM(BaseModel):
name: Union[UUID, str]
region: Region
sku: str
image: str
auth: Authentication
nsg: Optional[NSG]
tags: Optional[Dict[str, str]]
@validator("name", allow_reuse=True)
def check_name(cls, value: Union[UUID, str]) -> Union[UUID, str]:
if isinstance(value, str):
if len(value) > 40:
                # Azure truncates resource names longer than 40 bytes
raise ValueError("VM name too long")
return value
def is_deleted(self) -> bool:
        # A VM is considered deleted once the VM itself and all of its
        # resources (disks, NICs, IPs) have been deleted
return not has_components(str(self.name))
def exists(self) -> bool:
return self.get() is not None
def get(self) -> Optional[VirtualMachine]:
return get_vm(str(self.name))
def create(self) -> Union[None, Error]:
if self.get() is not None:
return None
logging.info("vm creating: %s", self.name)
return create_vm(
str(self.name),
self.region,
self.sku,
self.image,
self.auth.password,
self.auth.public_key,
self.nsg,
self.tags,
)
def delete(self) -> bool:
return delete_vm_components(str(self.name), self.nsg)
def add_extensions(self, extensions: List[Extension]) -> Union[bool, Error]:
status = []
to_create = []
for config in extensions:
if not isinstance(config["name"], str):
logging.error("vm agent - incompatable name: %s", repr(config))
continue
extension = get_extension(str(self.name), config["name"])
if extension:
logging.info(
"vm extension state: %s - %s - %s",
self.name,
config["name"],
extension.provisioning_state,
)
status.append(extension.provisioning_state)
else:
to_create.append(config)
if to_create:
for config in to_create:
create_extension(str(self.name), config)
else:
if all([x == "Succeeded" for x in status]):
return True
elif "Failed" in status:
return Error(
code=ErrorCode.VM_CREATE_FAILED,
errors=["failed to launch extension"],
)
elif not ("Creating" in status or "Updating" in status):
logging.error("vm agent - unknown state %s: %s", self.name, status)
return False

View File

@ -1,541 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Set, Union, cast
from uuid import UUID
from azure.core.exceptions import (
HttpResponseError,
ResourceExistsError,
ResourceNotFoundError,
)
from azure.mgmt.compute.models import (
ResourceSku,
ResourceSkuRestrictionsType,
VirtualMachineScaleSetVMInstanceIDs,
VirtualMachineScaleSetVMInstanceRequiredIDs,
VirtualMachineScaleSetVMListResult,
VirtualMachineScaleSetVMProtectionPolicy,
)
from memoization import cached
from msrestazure.azure_exceptions import CloudError
from onefuzztypes.enums import OS, ErrorCode
from onefuzztypes.models import Error
from onefuzztypes.primitives import Region
from .compute import get_compute_client
from .creds import (
get_base_resource_group,
get_scaleset_identity_resource_path,
retry_on_auth_failure,
)
from .image import get_os
@retry_on_auth_failure()
def list_vmss(
name: UUID,
vm_filter: Optional[Callable[[VirtualMachineScaleSetVMListResult], bool]] = None,
) -> Optional[List[str]]:
resource_group = get_base_resource_group()
client = get_compute_client()
try:
instances = [
x.instance_id
for x in client.virtual_machine_scale_set_vms.list(
resource_group, str(name)
)
if vm_filter is None or vm_filter(x)
]
return instances
except (ResourceNotFoundError, CloudError) as err:
logging.error("cloud error listing vmss: %s (%s)", name, err)
return None
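
# Illustrative sketch, not part of the original module: the optional vm_filter
# callback narrows the listing per instance. Assuming the SDK instance objects
# expose provisioning_state (as VirtualMachineScaleSetVM does), a hypothetical
# caller could list only fully provisioned nodes:
def _example_list_provisioned(name: UUID) -> Optional[List[str]]:
    def succeeded(vm: Any) -> bool:
        return bool(vm.provisioning_state == "Succeeded")

    return list_vmss(name, vm_filter=succeeded)
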
@retry_on_auth_failure()
def delete_vmss(name: UUID) -> bool:
resource_group = get_base_resource_group()
compute_client = get_compute_client()
response = compute_client.virtual_machine_scale_sets.begin_delete(
resource_group, str(name)
)
# https://docs.microsoft.com/en-us/python/api/azure-core/
# azure.core.polling.lropoller?view=azure-python#status--
#
    # status() returns a str, but mypy infers the result as Any.
#
# Checked by hand that the result is Succeeded in practice
return bool(response.status() == "Succeeded")
@retry_on_auth_failure()
def get_vmss(name: UUID) -> Optional[Any]:
resource_group = get_base_resource_group()
logging.debug("getting vm: %s", name)
compute_client = get_compute_client()
try:
return compute_client.virtual_machine_scale_sets.get(resource_group, str(name))
except ResourceNotFoundError as err:
logging.debug("vm does not exist %s", err)
return None
@retry_on_auth_failure()
def resize_vmss(name: UUID, capacity: int) -> None:
check_can_update(name)
resource_group = get_base_resource_group()
logging.info("updating VM count - name: %s vm_count: %d", name, capacity)
compute_client = get_compute_client()
try:
compute_client.virtual_machine_scale_sets.begin_update(
resource_group, str(name), {"sku": {"capacity": capacity}}
)
except ResourceExistsError as err:
logging.error(
"unable to resize scaleset. name:%s vm_count:%d - err:%s",
name,
capacity,
err,
)
except HttpResponseError as err:
if (
"models that may be referenced by one or more"
+ " VMs belonging to the Virtual Machine Scale Set"
in str(err)
):
logging.error(
"unable to resize scaleset due to model error. name: %s - err:%s",
name,
err,
)
@retry_on_auth_failure()
def get_vmss_size(name: UUID) -> Optional[int]:
vmss = get_vmss(name)
if vmss is None:
return None
return cast(int, vmss.sku.capacity)
@retry_on_auth_failure()
def list_instance_ids(name: UUID) -> Dict[UUID, str]:
logging.debug("get instance IDs for scaleset: %s", name)
resource_group = get_base_resource_group()
compute_client = get_compute_client()
results = {}
try:
for instance in compute_client.virtual_machine_scale_set_vms.list(
resource_group, str(name)
):
results[UUID(instance.vm_id)] = cast(str, instance.instance_id)
except (ResourceNotFoundError, CloudError):
logging.debug("vm does not exist %s", name)
return results
@cached(ttl=60)
@retry_on_auth_failure()
def get_instance_id(name: UUID, vm_id: UUID) -> Union[str, Error]:
resource_group = get_base_resource_group()
logging.info("get instance ID for scaleset node: %s:%s", name, vm_id)
compute_client = get_compute_client()
vm_id_str = str(vm_id)
for instance in compute_client.virtual_machine_scale_set_vms.list(
resource_group, str(name)
):
if instance.vm_id == vm_id_str:
return cast(str, instance.instance_id)
return Error(
code=ErrorCode.UNABLE_TO_FIND,
errors=["unable to find scaleset machine: %s:%s" % (name, vm_id)],
)
@retry_on_auth_failure()
def update_scale_in_protection(
name: UUID, vm_id: UUID, protect_from_scale_in: bool
) -> Optional[Error]:
instance_id = get_instance_id(name, vm_id)
if isinstance(instance_id, Error):
return instance_id
compute_client = get_compute_client()
resource_group = get_base_resource_group()
try:
instance_vm = compute_client.virtual_machine_scale_set_vms.get(
resource_group, name, instance_id
)
except (ResourceNotFoundError, CloudError):
return Error(
code=ErrorCode.UNABLE_TO_FIND,
errors=["unable to find vm instance: %s:%s" % (name, instance_id)],
)
new_protection_policy = VirtualMachineScaleSetVMProtectionPolicy(
protect_from_scale_in=protect_from_scale_in
)
if instance_vm.protection_policy is not None:
new_protection_policy = instance_vm.protection_policy
new_protection_policy.protect_from_scale_in = protect_from_scale_in
instance_vm.protection_policy = new_protection_policy
try:
compute_client.virtual_machine_scale_set_vms.begin_update(
resource_group, name, instance_id, instance_vm
)
except (ResourceNotFoundError, CloudError, HttpResponseError) as err:
if isinstance(err, HttpResponseError):
err_str = str(err)
instance_not_found = (
" is not an active Virtual Machine Scale Set VM instanceId."
)
if (
instance_not_found in err_str
and instance_vm.protection_policy.protect_from_scale_in is False
and protect_from_scale_in
== instance_vm.protection_policy.protect_from_scale_in
):
logging.info(
"Tried to remove scale in protection on node %s but the instance no longer exists" # noqa: E501
% instance_id
)
return None
if (
"models that may be referenced by one or more"
+ " VMs belonging to the Virtual Machine Scale Set"
in str(err)
):
logging.error(
"unable to resize scaleset due to model error. name: %s - err:%s",
name,
err,
)
return None
return Error(
code=ErrorCode.UNABLE_TO_UPDATE,
errors=["unable to set protection policy on: %s:%s" % (vm_id, instance_id)],
)
logging.info(
"Successfully set scale in protection on node %s to %s"
% (vm_id, protect_from_scale_in)
)
return None
class UnableToUpdate(Exception):
pass
def check_can_update(name: UUID) -> Any:
vmss = get_vmss(name)
if vmss is None:
raise UnableToUpdate
if vmss.provisioning_state == "Updating":
raise UnableToUpdate
return vmss
@retry_on_auth_failure()
def reimage_vmss_nodes(name: UUID, vm_ids: Set[UUID]) -> Optional[Error]:
check_can_update(name)
resource_group = get_base_resource_group()
compute_client = get_compute_client()
instance_ids = set()
machine_to_id = list_instance_ids(name)
for vm_id in vm_ids:
if vm_id in machine_to_id:
instance_ids.add(machine_to_id[vm_id])
else:
logging.info("unable to find vm_id for %s:%s", name, vm_id)
    # Nodes must be 'upgraded' before the reimage. This call makes sure
# the instance is up-to-date with the VMSS model.
# The expectation is that these requests are queued and handled subsequently.
# The VMSS Team confirmed this expectation and testing supports it, as well.
if instance_ids:
logging.info("upgrading VMSS nodes - name: %s vm_ids: %s", name, vm_id)
compute_client.virtual_machine_scale_sets.begin_update_instances(
resource_group,
str(name),
VirtualMachineScaleSetVMInstanceIDs(instance_ids=list(instance_ids)),
)
logging.info("reimaging VMSS nodes - name: %s vm_ids: %s", name, vm_id)
compute_client.virtual_machine_scale_sets.begin_reimage_all(
resource_group,
str(name),
VirtualMachineScaleSetVMInstanceIDs(instance_ids=list(instance_ids)),
)
return None
@retry_on_auth_failure()
def delete_vmss_nodes(name: UUID, vm_ids: Set[UUID]) -> Optional[Error]:
check_can_update(name)
resource_group = get_base_resource_group()
logging.info("deleting scaleset VM - name: %s vm_ids:%s", name, vm_ids)
compute_client = get_compute_client()
instance_ids = set()
machine_to_id = list_instance_ids(name)
for vm_id in vm_ids:
if vm_id in machine_to_id:
instance_ids.add(machine_to_id[vm_id])
else:
logging.info("unable to find vm_id for %s:%s", name, vm_id)
if instance_ids:
compute_client.virtual_machine_scale_sets.begin_delete_instances(
resource_group,
str(name),
VirtualMachineScaleSetVMInstanceRequiredIDs(
instance_ids=list(instance_ids)
),
)
return None
@retry_on_auth_failure()
def update_extensions(name: UUID, extensions: List[Any]) -> None:
check_can_update(name)
resource_group = get_base_resource_group()
logging.info("updating VM extensions: %s", name)
compute_client = get_compute_client()
try:
compute_client.virtual_machine_scale_sets.begin_update(
resource_group,
str(name),
{
"virtual_machine_profile": {
"extension_profile": {"extensions": extensions}
}
},
)
logging.info("VM extensions updated: %s", name)
except HttpResponseError as err:
if (
"models that may be referenced by one or more"
+ " VMs belonging to the Virtual Machine Scale Set"
in str(err)
):
logging.error(
"unable to resize scaleset due to model error. name: %s - err:%s",
name,
err,
)
@retry_on_auth_failure()
def create_vmss(
location: Region,
name: UUID,
vm_sku: str,
vm_count: int,
image: str,
network_id: str,
spot_instances: bool,
ephemeral_os_disks: bool,
extensions: List[Any],
password: str,
ssh_public_key: str,
tags: Dict[str, str],
) -> Optional[Error]:
vmss = get_vmss(name)
if vmss is not None:
return None
logging.info(
"creating VM "
"name: %s vm_sku: %s vm_count: %d "
"image: %s subnet: %s spot_instances: %s",
name,
vm_sku,
vm_count,
image,
network_id,
spot_instances,
)
resource_group = get_base_resource_group()
compute_client = get_compute_client()
if image.startswith("/"):
image_ref = {"id": image}
else:
image_val = image.split(":", 4)
image_ref = {
"publisher": image_val[0],
"offer": image_val[1],
"sku": image_val[2],
"version": image_val[3],
}
sku = {"name": vm_sku, "tier": "Standard", "capacity": vm_count}
params: Dict[str, Any] = {
"location": location,
"do_not_run_extensions_on_overprovisioned_vms": True,
"upgrade_policy": {"mode": "Manual"},
"sku": sku,
"overprovision": False,
"identity": {
"type": "userAssigned",
"userAssignedIdentities": {get_scaleset_identity_resource_path(): {}},
},
"virtual_machine_profile": {
"priority": "Regular",
"storage_profile": {
"image_reference": image_ref,
},
"os_profile": {
"computer_name_prefix": "node",
"admin_username": "onefuzz",
},
"network_profile": {
"network_interface_configurations": [
{
"name": "onefuzz-nic",
"primary": True,
"ip_configurations": [
{"name": "onefuzz-ip-config", "subnet": {"id": network_id}}
],
}
]
},
"extension_profile": {"extensions": extensions},
},
"single_placement_group": False,
}
image_os = get_os(location, image)
if isinstance(image_os, Error):
return image_os
if image_os == OS.windows:
params["virtual_machine_profile"]["os_profile"]["admin_password"] = password
if image_os == OS.linux:
params["virtual_machine_profile"]["os_profile"]["linux_configuration"] = {
"disable_password_authentication": True,
"ssh": {
"public_keys": [
{
"path": "/home/onefuzz/.ssh/authorized_keys",
"key_data": ssh_public_key,
}
]
},
}
if ephemeral_os_disks:
params["virtual_machine_profile"]["storage_profile"]["os_disk"] = {
"diffDiskSettings": {"option": "Local"},
"caching": "ReadOnly",
"createOption": "FromImage",
}
if spot_instances:
# Setting max price to -1 means it won't be evicted because of
# price.
#
# https://docs.microsoft.com/en-us/azure/
# virtual-machine-scale-sets/use-spot#resource-manager-templates
params["virtual_machine_profile"].update(
{
"eviction_policy": "Delete",
"priority": "Spot",
"billing_profile": {"max_price": -1},
}
)
params["tags"] = tags.copy()
owner = os.environ.get("ONEFUZZ_OWNER")
if owner:
params["tags"]["OWNER"] = owner
try:
compute_client.virtual_machine_scale_sets.begin_create_or_update(
resource_group, name, params
)
except ResourceExistsError as err:
err_str = str(err)
if "SkuNotAvailable" in err_str or "OperationNotAllowed" in err_str:
return Error(
code=ErrorCode.VM_CREATE_FAILED, errors=[f"creating vmss: {err_str}"]
)
raise err
except (ResourceNotFoundError, CloudError) as err:
if "The request failed due to conflict with a concurrent request" in repr(err):
logging.debug(
"create VM had conflicts with concurrent request, ignoring %s", err
)
return None
return Error(
code=ErrorCode.VM_CREATE_FAILED,
errors=["creating vmss: %s" % err],
)
except HttpResponseError as err:
err_str = str(err)
# Catch Gen2 Hypervisor / Image mismatch errors
# See https://github.com/microsoft/lisa/pull/716 for an example
if (
"check that the Hypervisor Generation of the Image matches the "
"Hypervisor Generation of the selected VM Size"
) in err_str:
return Error(
code=ErrorCode.VM_CREATE_FAILED, errors=[f"creating vmss: {err_str}"]
)
raise err
return None
@cached(ttl=60)
@retry_on_auth_failure()
def list_available_skus(region: Region) -> List[str]:
compute_client = get_compute_client()
skus: List[ResourceSku] = list(
compute_client.resource_skus.list(filter="location eq '%s'" % region)
)
sku_names: List[str] = []
for sku in skus:
available = True
if sku.restrictions is not None:
for restriction in sku.restrictions:
if restriction.type == ResourceSkuRestrictionsType.location and (
region.upper() in [v.upper() for v in restriction.values]
):
available = False
break
if available:
sku_names.append(sku.name)
return sku_names

View File

@ -1,34 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Optional, Tuple
from onefuzztypes.events import EventInstanceConfigUpdated
from onefuzztypes.models import InstanceConfig as BASE_CONFIG
from pydantic import Field
from .azure.creds import get_instance_name
from .events import send_event
from .orm import ORMMixin
class InstanceConfig(BASE_CONFIG, ORMMixin):
instance_name: str = Field(default_factory=get_instance_name)
@classmethod
def key_fields(cls) -> Tuple[str, Optional[str]]:
return ("instance_name", None)
@classmethod
def fetch(cls) -> "InstanceConfig":
entry = cls.get(get_instance_name())
if entry is None:
entry = cls(allowed_aad_tenants=[])
entry.save()
return entry
def save(self, new: bool = False, require_etag: bool = False) -> None:
super().save(new=new, require_etag=require_etag)
send_event(EventInstanceConfigUpdated(config=self))

View File

@ -1,204 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import urllib
from typing import Callable, Optional
from uuid import UUID
import azure.functions as func
from azure.functions import HttpRequest
from memoization import cached
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error, UserInfo
from .azure.creds import get_scaleset_principal_id
from .azure.group_membership import create_group_membership_checker
from .config import InstanceConfig
from .request import not_ok
from .request_access import RequestAccess
from .user_credentials import parse_jwt_token
from .workers.pools import Pool
from .workers.scalesets import Scaleset
@cached(ttl=60)
def get_rules() -> Optional[RequestAccess]:
config = InstanceConfig.fetch()
if config.api_access_rules:
return RequestAccess.build(config.api_access_rules)
else:
return None
def check_access(req: HttpRequest) -> Optional[Error]:
rules = get_rules()
# Nothing to enforce if there are no rules.
if not rules:
return None
path = urllib.parse.urlparse(req.url).path
rule = rules.get_matching_rules(req.method, path)
# No restriction defined on this endpoint.
if not rule:
return None
member_id = UUID(req.headers["x-ms-client-principal-id"])
try:
membership_checker = create_group_membership_checker()
allowed = membership_checker.is_member(rule.allowed_groups_ids, member_id)
if not allowed:
logging.error(
"unauthorized access: %s is not authorized to access in %s",
member_id,
req.url,
)
return Error(
code=ErrorCode.UNAUTHORIZED,
errors=["not approved to use this endpoint"],
)
except Exception as e:
return Error(
code=ErrorCode.UNAUTHORIZED,
errors=["unable to interact with graph", str(e)],
)
return None
@cached(ttl=60)
def is_agent(token_data: UserInfo) -> bool:
if token_data.object_id:
# backward compatibility case for scalesets deployed before the migration
# to user assigned managed id
scalesets = Scaleset.get_by_object_id(token_data.object_id)
if len(scalesets) > 0:
return True
# verify object_id against the user assigned managed identity
principal_id: UUID = get_scaleset_principal_id()
return principal_id == token_data.object_id
if not token_data.application_id:
return False
pools = Pool.search(query={"client_id": [token_data.application_id]})
if len(pools) > 0:
return True
return False
def can_modify_config_impl(config: InstanceConfig, user_info: UserInfo) -> bool:
if config.admins is None:
return True
return user_info.object_id in config.admins
def can_modify_config(req: func.HttpRequest, config: InstanceConfig) -> bool:
user_info = parse_jwt_token(req)
if not isinstance(user_info, UserInfo):
return False
return can_modify_config_impl(config, user_info)
def check_require_admins_impl(
config: InstanceConfig, user_info: UserInfo
) -> Optional[Error]:
if not config.require_admin_privileges:
return None
if config.admins is None:
return Error(code=ErrorCode.UNAUTHORIZED, errors=["pool modification disabled"])
if user_info.object_id in config.admins:
return None
return Error(code=ErrorCode.UNAUTHORIZED, errors=["not authorized to manage pools"])
def check_require_admins(req: func.HttpRequest) -> Optional[Error]:
user_info = parse_jwt_token(req)
if isinstance(user_info, Error):
return user_info
# When there are no admins in the `admins` list, all users are considered
# admins. However, `require_admin_privileges` is still useful to protect from
# mistakes.
#
# To make changes while still protecting against accidental changes to
# pools, do the following:
#
# 1. set `require_admin_privileges` to `False`
# 2. make the change
# 3. set `require_admin_privileges` to `True`
config = InstanceConfig.fetch()
return check_require_admins_impl(config, user_info)
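
# Illustrative sketch, not part of the original module: the *_impl helpers keep
# the authorization rules pure, so they can be exercised directly. In this
# hypothetical check, given a config with require_admin_privileges set and an
# explicit admins list, only listed object_ids pass:
def _example_is_allowed(config: InstanceConfig, user_info: UserInfo) -> bool:
    # True when the request may proceed, False when an Error would be returned
    return check_require_admins_impl(config, user_info) is None
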
def is_user(token_data: UserInfo) -> bool:
return not is_agent(token_data)
def reject(req: func.HttpRequest, token: UserInfo) -> func.HttpResponse:
logging.error(
"reject token. url:%s token:%s body:%s",
repr(req.url),
repr(token),
repr(req.get_body()),
)
return not_ok(
Error(code=ErrorCode.UNAUTHORIZED, errors=["Unrecognized agent"]),
status_code=401,
context="token verification",
)
def call_if(
req: func.HttpRequest,
method: Callable[[func.HttpRequest], func.HttpResponse],
*,
allow_user: bool = False,
allow_agent: bool = False
) -> func.HttpResponse:
token = parse_jwt_token(req)
if isinstance(token, Error):
return not_ok(token, status_code=401, context="token verification")
if is_user(token):
if not allow_user:
return reject(req, token)
access = check_access(req)
if isinstance(access, Error):
return not_ok(access, status_code=401, context="access control")
if is_agent(token) and not allow_agent:
return reject(req, token)
return method(req)
def call_if_user(
req: func.HttpRequest, method: Callable[[func.HttpRequest], func.HttpResponse]
) -> func.HttpResponse:
return call_if(req, method, allow_user=True)
def call_if_agent(
req: func.HttpRequest, method: Callable[[func.HttpRequest], func.HttpResponse]
) -> func.HttpResponse:
return call_if(req, method, allow_agent=True)

View File

@ -1,81 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
from typing import List
from onefuzztypes.events import Event, EventMessage, EventType, get_event_type
from onefuzztypes.models import UserInfo
from pydantic import BaseModel
from .azure.creds import get_instance_id, get_instance_name
from .azure.queue import send_message
from .azure.storage import StorageType
from .webhooks import Webhook
class SignalREvent(BaseModel):
target: str
arguments: List[EventMessage]
def queue_signalr_event(event_message: EventMessage) -> None:
message = SignalREvent(target="events", arguments=[event_message]).json().encode()
send_message("signalr-events", message, StorageType.config)
def log_event(event: Event, event_type: EventType) -> None:
scrubbed_event = filter_event(event)
logging.info(
"sending event: %s - %s", event_type, scrubbed_event.json(exclude_none=True)
)
def filter_event(event: Event) -> BaseModel:
clone_event = event.copy(deep=True)
filtered_event = filter_event_recurse(clone_event)
return filtered_event
def filter_event_recurse(entry: BaseModel) -> BaseModel:
for field in entry.__fields__:
field_data = getattr(entry, field)
if isinstance(field_data, UserInfo):
field_data = None
elif isinstance(field_data, list):
for (i, value) in enumerate(field_data):
if isinstance(value, BaseModel):
field_data[i] = filter_event_recurse(value)
elif isinstance(field_data, dict):
for (key, value) in field_data.items():
if isinstance(value, BaseModel):
field_data[key] = filter_event_recurse(value)
elif isinstance(field_data, BaseModel):
field_data = filter_event_recurse(field_data)
setattr(entry, field, field_data)
return entry
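
# Illustrative sketch, not part of the original module: filter_event scrubs a
# deep copy, blanking any UserInfo values it finds so PII never reaches the
# event sinks. Assuming an event type that carries an optional user_info field,
# the scrubbed copy reports it as None while the original is left untouched:
def _example_event_is_scrubbed(event: Event) -> bool:
    scrubbed = filter_event(event)
    return getattr(scrubbed, "user_info", None) is None
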
def send_event(event: Event) -> None:
event_type = get_event_type(event)
event_message = EventMessage(
event_type=event_type,
event=event.copy(deep=True),
instance_id=get_instance_id(),
instance_name=get_instance_name(),
)
# work around odd bug with Event Message creation. See PR 939
if event_message.event != event:
event_message.event = event.copy(deep=True)
log_event(event, event_type)
queue_signalr_event(event_message)
Webhook.send_event(event_message)

View File

@ -1,515 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
from typing import List, Optional
from uuid import UUID, uuid4
from onefuzztypes.enums import OS, AgentMode
from onefuzztypes.models import (
AgentConfig,
AzureMonitorExtensionConfig,
KeyvaultExtensionConfig,
Pool,
ReproConfig,
Scaleset,
)
from onefuzztypes.primitives import Container, Extension, Region
from .azure.containers import (
get_container_sas_url,
get_file_sas_url,
get_file_url,
save_blob,
)
from .azure.creds import get_agent_instance_url, get_instance_id
from .azure.log_analytics import get_monitor_settings
from .azure.queue import get_queue_sas
from .azure.storage import StorageType
from .config import InstanceConfig
from .reports import get_report
def generic_extensions(region: Region, vm_os: OS) -> List[Extension]:
instance_config = InstanceConfig.fetch()
extensions = [monitor_extension(region, vm_os)]
dependency = dependency_extension(region, vm_os)
if dependency:
extensions.append(dependency)
if instance_config.extensions:
if instance_config.extensions.keyvault:
keyvault = keyvault_extension(
region, instance_config.extensions.keyvault, vm_os
)
extensions.append(keyvault)
if instance_config.extensions.geneva and vm_os == OS.windows:
geneva = geneva_extension(region)
extensions.append(geneva)
if instance_config.extensions.azure_monitor and vm_os == OS.linux:
azmon = azmon_extension(region, instance_config.extensions.azure_monitor)
extensions.append(azmon)
if instance_config.extensions.azure_security and vm_os == OS.linux:
azsec = azsec_extension(region)
extensions.append(azsec)
return extensions
def monitor_extension(region: Region, vm_os: OS) -> Extension:
settings = get_monitor_settings()
if vm_os == OS.windows:
return {
"name": "OMSExtension",
"publisher": "Microsoft.EnterpriseCloud.Monitoring",
"type": "MicrosoftMonitoringAgent",
"typeHandlerVersion": "1.0",
"location": region,
"autoUpgradeMinorVersion": True,
"settings": {"workspaceId": settings["id"]},
"protectedSettings": {"workspaceKey": settings["key"]},
}
elif vm_os == OS.linux:
return {
"name": "OMSExtension",
"publisher": "Microsoft.EnterpriseCloud.Monitoring",
"type": "OmsAgentForLinux",
"typeHandlerVersion": "1.12",
"location": region,
"autoUpgradeMinorVersion": True,
"settings": {"workspaceId": settings["id"]},
"protectedSettings": {"workspaceKey": settings["key"]},
}
raise NotImplementedError("unsupported os: %s" % vm_os)
def dependency_extension(region: Region, vm_os: OS) -> Optional[Extension]:
if vm_os == OS.windows:
extension = {
"name": "DependencyAgentWindows",
"publisher": "Microsoft.Azure.Monitoring.DependencyAgent",
"type": "DependencyAgentWindows",
"typeHandlerVersion": "9.5",
"location": region,
"autoUpgradeMinorVersion": True,
}
return extension
else:
# TODO: dependency agent for linux is not reliable
# extension = {
# "name": "DependencyAgentLinux",
# "publisher": "Microsoft.Azure.Monitoring.DependencyAgent",
# "type": "DependencyAgentLinux",
# "typeHandlerVersion": "9.5",
# "location": vm.region,
# "autoUpgradeMinorVersion": True,
# }
return None
def geneva_extension(region: Region) -> Extension:
return {
"name": "Microsoft.Azure.Geneva.GenevaMonitoring",
"publisher": "Microsoft.Azure.Geneva",
"type": "GenevaMonitoring",
"typeHandlerVersion": "2.0",
"location": region,
"autoUpgradeMinorVersion": True,
"enableAutomaticUpgrade": True,
"settings": {},
"protectedSettings": {},
}
def azmon_extension(
region: Region, azure_monitor: AzureMonitorExtensionConfig
) -> Extension:
auth_id = azure_monitor.monitoringGCSAuthId
config_version = azure_monitor.config_version
moniker = azure_monitor.moniker
namespace = azure_monitor.namespace
environment = azure_monitor.monitoringGSEnvironment
account = azure_monitor.monitoringGCSAccount
auth_id_type = azure_monitor.monitoringGCSAuthIdType
return {
"name": "AzureMonitorLinuxAgent",
"publisher": "Microsoft.Azure.Monitor",
"location": region,
"type": "AzureMonitorLinuxAgent",
"typeHandlerVersion": "1.0",
"autoUpgradeMinorVersion": True,
"settings": {"GCS_AUTO_CONFIG": True},
"protectedsettings": {
"configVersion": config_version,
"moniker": moniker,
"namespace": namespace,
"monitoringGCSEnvironment": environment,
"monitoringGCSAccount": account,
"monitoringGCSRegion": region,
"monitoringGCSAuthId": auth_id,
"monitoringGCSAuthIdType": auth_id_type,
},
}
def azsec_extension(region: Region) -> Extension:
return {
"name": "AzureSecurityLinuxAgent",
"publisher": "Microsoft.Azure.Security.Monitoring",
"location": region,
"type": "AzureSecurityLinuxAgent",
"typeHandlerVersion": "2.0",
"autoUpgradeMinorVersion": True,
"settings": {"enableGenevaUpload": True, "enableAutoConfig": True},
}
def keyvault_extension(
region: Region, keyvault: KeyvaultExtensionConfig, vm_os: OS
) -> Extension:
keyvault_name = keyvault.keyvault_name
cert_name = keyvault.cert_name
uri = keyvault_name + cert_name
if vm_os == OS.windows:
return {
"name": "KVVMExtensionForWindows",
"location": region,
"publisher": "Microsoft.Azure.KeyVault",
"type": "KeyVaultForWindows",
"typeHandlerVersion": "1.0",
"autoUpgradeMinorVersion": True,
"settings": {
"secretsManagementSettings": {
"pollingIntervalInS": "3600",
"certificateStoreName": "MY",
"linkOnRenewal": False,
"certificateStoreLocation": "LocalMachine",
"requireInitialSync": True,
"observedCertificates": [uri],
}
},
}
elif vm_os == OS.linux:
cert_path = keyvault.cert_path
extension_store = keyvault.extension_store
cert_location = cert_path + extension_store
return {
"name": "KVVMExtensionForLinux",
"location": region,
"publisher": "Microsoft.Azure.KeyVault",
"type": "KeyVaultForLinux",
"typeHandlerVersion": "2.0",
"autoUpgradeMinorVersion": True,
"settings": {
"secretsManagementSettings": {
"pollingIntervalInS": "3600",
"certificateStoreLocation": cert_location,
"observedCertificates": [uri],
},
},
}
raise NotImplementedError("unsupported os: %s" % vm_os)
def build_scaleset_script(pool: Pool, scaleset: Scaleset) -> str:
commands = []
extension = "ps1" if pool.os == OS.windows else "sh"
filename = f"{scaleset.scaleset_id}/scaleset-setup.{extension}"
sep = "\r\n" if pool.os == OS.windows else "\n"
if pool.os == OS.windows and scaleset.auth is not None:
ssh_key = scaleset.auth.public_key.strip()
ssh_path = "$env:ProgramData/ssh/administrators_authorized_keys"
commands += [f'Set-Content -Path {ssh_path} -Value "{ssh_key}"']
save_blob(
Container("vm-scripts"), filename, sep.join(commands) + sep, StorageType.config
)
return get_file_url(Container("vm-scripts"), filename, StorageType.config)
def build_pool_config(pool: Pool) -> str:
config = AgentConfig(
pool_name=pool.name,
onefuzz_url=get_agent_instance_url(),
heartbeat_queue=get_queue_sas(
"node-heartbeat",
StorageType.config,
add=True,
),
instance_telemetry_key=os.environ.get("APPINSIGHTS_INSTRUMENTATIONKEY"),
microsoft_telemetry_key=os.environ.get("ONEFUZZ_TELEMETRY"),
instance_id=get_instance_id(),
)
multi_tenant_domain = os.environ.get("MULTI_TENANT_DOMAIN")
if multi_tenant_domain:
config.multi_tenant_domain = multi_tenant_domain
filename = f"{pool.name}/config.json"
save_blob(
Container("vm-scripts"),
filename,
config.json(),
StorageType.config,
)
return config_url(Container("vm-scripts"), filename, False)
def update_managed_scripts() -> None:
commands = [
"azcopy sync '%s' instance-specific-setup"
% (
get_container_sas_url(
Container("instance-specific-setup"),
StorageType.config,
read=True,
list_=True,
)
),
"azcopy sync '%s' tools"
% (
get_container_sas_url(
Container("tools"), StorageType.config, read=True, list_=True
)
),
]
save_blob(
Container("vm-scripts"),
"managed.ps1",
"\r\n".join(commands) + "\r\n",
StorageType.config,
)
save_blob(
Container("vm-scripts"),
"managed.sh",
"\n".join(commands) + "\n",
StorageType.config,
)
def config_url(container: Container, filename: str, with_sas: bool) -> str:
if with_sas:
return get_file_sas_url(container, filename, StorageType.config, read=True)
else:
return get_file_url(container, filename, StorageType.config)
def agent_config(
region: Region,
vm_os: OS,
mode: AgentMode,
*,
urls: Optional[List[str]] = None,
with_sas: bool = False,
) -> Extension:
update_managed_scripts()
if urls is None:
urls = []
if vm_os == OS.windows:
urls += [
config_url(Container("vm-scripts"), "managed.ps1", with_sas),
config_url(Container("tools"), "win64/azcopy.exe", with_sas),
config_url(
Container("tools"),
"win64/setup.ps1",
with_sas,
),
config_url(
Container("tools"),
"win64/onefuzz.ps1",
with_sas,
),
]
to_execute_cmd = (
"powershell -ExecutionPolicy Unrestricted -File win64/setup.ps1 "
"-mode %s" % (mode.name)
)
extension = {
"name": "CustomScriptExtension",
"type": "CustomScriptExtension",
"publisher": "Microsoft.Compute",
"location": region,
"force_update_tag": uuid4(),
"type_handler_version": "1.9",
"auto_upgrade_minor_version": True,
"settings": {
"commandToExecute": to_execute_cmd,
"fileUris": urls,
},
"protectedSettings": {
"managedIdentity": {},
},
}
return extension
elif vm_os == OS.linux:
urls += [
config_url(
Container("vm-scripts"),
"managed.sh",
with_sas,
),
config_url(
Container("tools"),
"linux/azcopy",
with_sas,
),
config_url(
Container("tools"),
"linux/setup.sh",
with_sas,
),
]
to_execute_cmd = "bash setup.sh %s" % (mode.name)
extension = {
"name": "CustomScript",
"publisher": "Microsoft.Azure.Extensions",
"type": "CustomScript",
"typeHandlerVersion": "2.1",
"location": region,
"force_update_tag": uuid4(),
"autoUpgradeMinorVersion": True,
"settings": {
"commandToExecute": to_execute_cmd,
"fileUris": urls,
},
"protectedSettings": {
"managedIdentity": {},
},
}
return extension
raise NotImplementedError("unsupported OS: %s" % vm_os)
def fuzz_extensions(pool: Pool, scaleset: Scaleset) -> List[Extension]:
urls = [build_pool_config(pool), build_scaleset_script(pool, scaleset)]
fuzz_extension = agent_config(scaleset.region, pool.os, AgentMode.fuzz, urls=urls)
extensions = generic_extensions(scaleset.region, pool.os)
extensions += [fuzz_extension]
return extensions
def repro_extensions(
region: Region,
repro_os: OS,
repro_id: UUID,
repro_config: ReproConfig,
setup_container: Optional[Container],
) -> List[Extension]:
# TODO - what about contents of repro.ps1 / repro.sh?
report = get_report(repro_config.container, repro_config.path)
if report is None:
raise Exception("invalid report: %s" % repro_config)
if report.input_blob is None:
raise Exception("unable to perform reproduction without an input blob")
commands = []
if setup_container:
commands += [
"azcopy sync '%s' ./setup"
% (
get_container_sas_url(
setup_container, StorageType.corpus, read=True, list_=True
)
),
]
urls = [
get_file_sas_url(
repro_config.container, repro_config.path, StorageType.corpus, read=True
),
get_file_sas_url(
report.input_blob.container,
report.input_blob.name,
StorageType.corpus,
read=True,
),
]
repro_files = []
if repro_os == OS.windows:
repro_files = ["%s/repro.ps1" % repro_id]
task_script = "\r\n".join(commands)
script_name = "task-setup.ps1"
else:
repro_files = ["%s/repro.sh" % repro_id, "%s/repro-stdout.sh" % repro_id]
commands += ["chmod -R +x setup"]
task_script = "\n".join(commands)
script_name = "task-setup.sh"
save_blob(
Container("task-configs"),
"%s/%s" % (repro_id, script_name),
task_script,
StorageType.config,
)
for repro_file in repro_files:
urls += [
get_file_sas_url(
Container("repro-scripts"),
repro_file,
StorageType.config,
read=True,
),
get_file_sas_url(
Container("task-configs"),
"%s/%s" % (repro_id, script_name),
StorageType.config,
read=True,
),
]
base_extension = agent_config(
region, repro_os, AgentMode.repro, urls=urls, with_sas=True
)
extensions = generic_extensions(region, repro_os)
extensions += [base_extension]
return extensions
def proxy_manager_extensions(region: Region, proxy_id: UUID) -> List[Extension]:
urls = [
get_file_sas_url(
Container("proxy-configs"),
"%s/%s/config.json" % (region, proxy_id),
StorageType.config,
read=True,
),
get_file_sas_url(
Container("tools"),
"linux/onefuzz-proxy-manager",
StorageType.config,
read=True,
),
]
base_extension = agent_config(
region, OS.linux, AgentMode.proxy, urls=urls, with_sas=True
)
extensions = generic_extensions(region, OS.linux)
extensions += [base_extension]
return extensions

View File

@ -1,14 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .afl import afl_linux, afl_windows
from .libfuzzer import libfuzzer_linux, libfuzzer_windows
TEMPLATES = {
"afl_windows": afl_windows,
"afl_linux": afl_linux,
"libfuzzer_linux": libfuzzer_linux,
"libfuzzer_windows": libfuzzer_windows,
}

View File

@ -1,271 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from uuid import UUID
from onefuzztypes.enums import (
OS,
ContainerType,
TaskType,
UserFieldOperation,
UserFieldType,
)
from onefuzztypes.job_templates import JobTemplate, UserField, UserFieldLocation
from onefuzztypes.models import (
JobConfig,
TaskConfig,
TaskContainers,
TaskDetails,
TaskPool,
)
from onefuzztypes.primitives import Container, PoolName
from .common import (
DURATION_HELP,
POOL_HELP,
REBOOT_HELP,
RETRY_COUNT_HELP,
TAGS_HELP,
TARGET_EXE_HELP,
TARGET_OPTIONS_HELP,
VM_COUNT_HELP,
)
afl_linux = JobTemplate(
os=OS.linux,
job=JobConfig(project="", name=Container(""), build="", duration=1),
tasks=[
TaskConfig(
job_id=(UUID(int=0)),
task=TaskDetails(
type=TaskType.generic_supervisor,
duration=1,
target_exe="fuzz.exe",
target_env={},
target_options=[],
supervisor_exe="",
supervisor_options=[],
supervisor_input_marker="@@",
),
pool=TaskPool(count=1, pool_name=PoolName("")),
containers=[
TaskContainers(
name=Container("afl-container-name"), type=ContainerType.tools
),
TaskContainers(name=Container(""), type=ContainerType.setup),
TaskContainers(name=Container(""), type=ContainerType.crashes),
TaskContainers(name=Container(""), type=ContainerType.inputs),
],
tags={},
),
TaskConfig(
job_id=UUID(int=0),
prereq_tasks=[UUID(int=0)],
task=TaskDetails(
type=TaskType.generic_crash_report,
duration=1,
target_exe="fuzz.exe",
target_env={},
target_options=[],
check_debugger=True,
),
pool=TaskPool(count=1, pool_name=PoolName("")),
containers=[
TaskContainers(name=Container(""), type=ContainerType.setup),
TaskContainers(name=Container(""), type=ContainerType.crashes),
TaskContainers(name=Container(""), type=ContainerType.no_repro),
TaskContainers(name=Container(""), type=ContainerType.reports),
TaskContainers(name=Container(""), type=ContainerType.unique_reports),
],
tags={},
),
],
notifications=[],
user_fields=[
UserField(
name="pool_name",
help=POOL_HELP,
type=UserFieldType.Str,
required=True,
locations=[
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/0/pool/pool_name",
),
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/1/pool/pool_name",
),
],
),
UserField(
name="duration",
help=DURATION_HELP,
type=UserFieldType.Int,
default=24,
locations=[
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/0/task/duration",
),
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/1/task/duration",
),
UserFieldLocation(op=UserFieldOperation.replace, path="/job/duration"),
],
),
UserField(
name="target_exe",
help=TARGET_EXE_HELP,
type=UserFieldType.Str,
default="fuzz.exe",
locations=[
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/0/task/target_exe",
),
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/1/task/target_exe",
),
],
),
UserField(
name="target_options",
help=TARGET_OPTIONS_HELP,
type=UserFieldType.ListStr,
locations=[
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/0/task/target_options",
),
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/1/task/target_options",
),
],
),
UserField(
name="supervisor_exe",
help="Path to the AFL executable",
type=UserFieldType.Str,
default="{tools_dir}/afl-fuzz",
locations=[
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/0/task/supervisor_exe",
),
],
),
UserField(
name="supervisor_options",
help="AFL command line options",
type=UserFieldType.ListStr,
default=[
"-d",
"-i",
"{input_corpus}",
"-o",
"{runtime_dir}",
"--",
"{target_exe}",
"{target_options}",
],
locations=[
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/0/task/supervisor_options",
),
],
),
UserField(
name="supervisor_env",
help="Enviornment variables for AFL",
type=UserFieldType.DictStr,
locations=[
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/0/task/supervisor_env",
),
],
),
UserField(
name="vm_count",
help=VM_COUNT_HELP,
type=UserFieldType.Int,
default=2,
locations=[
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/0/pool/count",
),
],
),
UserField(
name="check_retry_count",
help=RETRY_COUNT_HELP,
type=UserFieldType.Int,
locations=[
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/1/task/check_retry_count",
),
],
),
UserField(
name="afl_container",
help=(
"Name of the AFL storage container (use "
"this to specify alternate builds of AFL)"
),
type=UserFieldType.Str,
default="afl-linux",
locations=[
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/0/containers/0/name",
),
],
),
UserField(
name="reboot_after_setup",
help=REBOOT_HELP,
type=UserFieldType.Bool,
default=False,
locations=[
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/0/task/reboot_after_setup",
),
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/1/task/reboot_after_setup",
),
],
),
UserField(
name="tags",
help=TAGS_HELP,
type=UserFieldType.DictStr,
locations=[
UserFieldLocation(
op=UserFieldOperation.add,
path="/tasks/0/tags",
),
UserFieldLocation(
op=UserFieldOperation.add,
path="/tasks/1/tags",
),
],
),
],
)
afl_windows = afl_linux.copy(deep=True)
afl_windows.os = OS.windows
for user_field in afl_windows.user_fields:
if user_field.name == "afl_container":
user_field.default = "afl-windows"

View File

@ -1,14 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
POOL_HELP = "Execute the task on the specified pool"
DURATION_HELP = "Number of hours to execute the task"
TARGET_EXE_HELP = "Path to the target executable"
TARGET_OPTIONS_HELP = "Command line options for the target"
VM_COUNT_HELP = "Number of VMs to use for fuzzing"
RETRY_COUNT_HELP = "Number of times to retry a crash to verify reproducibility"
REBOOT_HELP = "After executing the setup script, reboot the VM"
TAGS_HELP = "User provided metadata for the tasks"

View File

@ -1,348 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from uuid import UUID
from onefuzztypes.enums import (
OS,
ContainerType,
TaskType,
UserFieldOperation,
UserFieldType,
)
from onefuzztypes.job_templates import JobTemplate, UserField, UserFieldLocation
from onefuzztypes.models import (
JobConfig,
TaskConfig,
TaskContainers,
TaskDetails,
TaskPool,
)
from onefuzztypes.primitives import Container, PoolName
from .common import (
DURATION_HELP,
POOL_HELP,
REBOOT_HELP,
RETRY_COUNT_HELP,
TAGS_HELP,
TARGET_EXE_HELP,
TARGET_OPTIONS_HELP,
VM_COUNT_HELP,
)
libfuzzer_linux = JobTemplate(
os=OS.linux,
job=JobConfig(project="", name=Container(""), build="", duration=1),
tasks=[
TaskConfig(
job_id=UUID(int=0),
task=TaskDetails(
type=TaskType.libfuzzer_fuzz,
duration=1,
target_exe="fuzz.exe",
target_env={},
target_options=[],
),
pool=TaskPool(count=1, pool_name=PoolName("")),
containers=[
TaskContainers(name=Container(""), type=ContainerType.setup),
TaskContainers(name=Container(""), type=ContainerType.crashes),
TaskContainers(name=Container(""), type=ContainerType.inputs),
],
tags={},
),
TaskConfig(
job_id=UUID(int=0),
prereq_tasks=[UUID(int=0)],
task=TaskDetails(
type=TaskType.libfuzzer_crash_report,
duration=1,
target_exe="fuzz.exe",
target_env={},
target_options=[],
),
pool=TaskPool(count=1, pool_name=PoolName("")),
containers=[
TaskContainers(name=Container(""), type=ContainerType.setup),
TaskContainers(name=Container(""), type=ContainerType.crashes),
TaskContainers(name=Container(""), type=ContainerType.no_repro),
TaskContainers(name=Container(""), type=ContainerType.reports),
TaskContainers(name=Container(""), type=ContainerType.unique_reports),
],
tags={},
),
TaskConfig(
job_id=UUID(int=0),
prereq_tasks=[UUID(int=0)],
task=TaskDetails(
type=TaskType.coverage,
duration=1,
target_exe="fuzz.exe",
target_env={},
target_options=[],
),
pool=TaskPool(count=1, pool_name=PoolName("")),
containers=[
TaskContainers(name=Container(""), type=ContainerType.setup),
TaskContainers(name=Container(""), type=ContainerType.readonly_inputs),
TaskContainers(name=Container(""), type=ContainerType.coverage),
],
tags={},
),
],
notifications=[],
user_fields=[
UserField(
name="pool_name",
help=POOL_HELP,
type=UserFieldType.Str,
required=True,
locations=[
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/0/pool/pool_name",
),
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/1/pool/pool_name",
),
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/2/pool/pool_name",
),
],
),
UserField(
name="target_exe",
help=TARGET_EXE_HELP,
type=UserFieldType.Str,
default="fuzz.exe",
locations=[
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/0/task/target_exe",
),
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/1/task/target_exe",
),
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/2/task/target_exe",
),
],
),
UserField(
name="duration",
help=DURATION_HELP,
type=UserFieldType.Int,
default=24,
locations=[
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/0/task/duration",
),
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/1/task/duration",
),
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/2/task/duration",
),
UserFieldLocation(op=UserFieldOperation.replace, path="/job/duration"),
],
),
UserField(
name="target_workers",
help="Number of instances of the libfuzzer target on each VM",
type=UserFieldType.Int,
locations=[
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/0/task/target_workers",
),
],
),
UserField(
name="vm_count",
help=VM_COUNT_HELP,
type=UserFieldType.Int,
default=2,
locations=[
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/0/pool/count",
),
],
),
UserField(
name="target_options",
help=TARGET_OPTIONS_HELP,
type=UserFieldType.ListStr,
locations=[
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/0/task/target_options",
),
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/1/task/target_options",
),
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/2/task/target_options",
),
],
),
UserField(
name="target_env",
help="Environment variables for the target",
type=UserFieldType.DictStr,
locations=[
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/0/task/target_env",
),
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/1/task/target_env",
),
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/2/task/target_env",
),
],
),
UserField(
name="check_fuzzer_help",
help="Verify fuzzer by checking if it supports -help=1",
type=UserFieldType.Bool,
default=True,
locations=[
UserFieldLocation(
op=UserFieldOperation.add,
path="/tasks/0/task/check_fuzzer_help",
),
UserFieldLocation(
op=UserFieldOperation.add,
path="/tasks/1/task/check_fuzzer_help",
),
UserFieldLocation(
op=UserFieldOperation.add,
path="/tasks/2/task/check_fuzzer_help",
),
],
),
UserField(
name="colocate",
help="Run all of the tasks on the same node",
type=UserFieldType.Bool,
default=True,
locations=[
UserFieldLocation(
op=UserFieldOperation.add,
path="/tasks/0/colocate",
),
UserFieldLocation(
op=UserFieldOperation.add,
path="/tasks/1/colocate",
),
UserFieldLocation(
op=UserFieldOperation.add,
path="/tasks/2/colocate",
),
],
),
UserField(
name="expect_crash_on_failure",
help="Require crashes upon non-zero exits from libfuzzer",
type=UserFieldType.Bool,
default=False,
locations=[
UserFieldLocation(
op=UserFieldOperation.add,
path="/tasks/0/task/expect_crash_on_failure",
),
],
),
UserField(
name="reboot_after_setup",
help=REBOOT_HELP,
type=UserFieldType.Bool,
default=False,
locations=[
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/0/task/reboot_after_setup",
),
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/1/task/reboot_after_setup",
),
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/2/task/reboot_after_setup",
),
],
),
UserField(
name="check_retry_count",
help=RETRY_COUNT_HELP,
type=UserFieldType.Int,
locations=[
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/1/task/check_retry_count",
),
],
),
UserField(
name="target_timeout",
help="Number of seconds to timeout during reproduction",
type=UserFieldType.Int,
locations=[
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/1/task/target_timeout",
),
],
),
UserField(
name="minimized_stack_depth",
help="Number of frames to include in the minimized stack",
type=UserFieldType.Int,
locations=[
UserFieldLocation(
op=UserFieldOperation.replace,
path="/tasks/1/task/minimized_stack_depth",
),
],
),
UserField(
name="tags",
help=TAGS_HELP,
type=UserFieldType.DictStr,
locations=[
UserFieldLocation(
op=UserFieldOperation.add,
path="/tasks/0/tags",
),
UserFieldLocation(
op=UserFieldOperation.add,
path="/tasks/1/tags",
),
UserFieldLocation(
op=UserFieldOperation.add,
path="/tasks/2/tags",
),
],
),
],
)
libfuzzer_windows = libfuzzer_linux.copy(deep=True)
libfuzzer_windows.os = OS.windows

View File

@ -1,131 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import json
from typing import Dict, List
from jsonpatch import apply_patch
from memoization import cached
from onefuzztypes.enums import ContainerType, ErrorCode, UserFieldType
from onefuzztypes.job_templates import (
TEMPLATE_BASE_FIELDS,
JobTemplate,
JobTemplateConfig,
JobTemplateField,
JobTemplateRequest,
TemplateUserData,
UserField,
)
from onefuzztypes.models import Error, Result
def template_container_types(template: JobTemplate) -> List[ContainerType]:
return list(set(c.type for t in template.tasks for c in t.containers if not c.name))
@cached
def build_input_config(name: str, template: JobTemplate) -> JobTemplateConfig:
user_fields = [
JobTemplateField(
name=x.name,
type=x.type,
required=x.required,
default=x.default,
help=x.help,
)
for x in TEMPLATE_BASE_FIELDS + template.user_fields
]
containers = template_container_types(template)
return JobTemplateConfig(
os=template.os,
name=name,
user_fields=user_fields,
containers=containers,
)
def build_patches(
data: TemplateUserData, field: UserField
) -> List[Dict[str, TemplateUserData]]:
patches = []
if field.type == UserFieldType.Bool and not isinstance(data, bool):
raise Exception("invalid bool field")
if field.type == UserFieldType.Int and not isinstance(data, int):
raise Exception("invalid int field")
if field.type == UserFieldType.Str and not isinstance(data, str):
raise Exception("invalid str field")
if field.type == UserFieldType.DictStr and not isinstance(data, dict):
raise Exception("invalid DictStr field")
if field.type == UserFieldType.ListStr and not isinstance(data, list):
raise Exception("invalid ListStr field")
for location in field.locations:
patches.append(
{
"op": location.op.name,
"path": location.path,
"value": data,
}
)
return patches
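
# Illustrative sketch, not part of the original module: build_patches emits
# plain JSON Patch operations that render() later applies with
# jsonpatch.apply_patch. A hypothetical single replace, filling in a pool name:
def _example_apply_pool_name() -> Dict:
    doc = {"tasks": [{"pool": {"pool_name": ""}}]}
    ops = [{"op": "replace", "path": "/tasks/0/pool/pool_name", "value": "my-pool"}]
    # returns a new document: {"tasks": [{"pool": {"pool_name": "my-pool"}}]}
    return apply_patch(doc, ops)
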
def _fail(why: str) -> Error:
return Error(code=ErrorCode.INVALID_REQUEST, errors=[why])
def render(request: JobTemplateRequest, template: JobTemplate) -> Result[JobTemplate]:
patches = []
seen = set()
for name in request.user_fields:
for field in TEMPLATE_BASE_FIELDS + template.user_fields:
if field.name == name:
if name in seen:
return _fail(f"duplicate specification: {name}")
seen.add(name)
if name not in seen:
return _fail(f"extra field: {name}")
for field in TEMPLATE_BASE_FIELDS + template.user_fields:
if field.name not in request.user_fields:
if field.required:
return _fail(f"missing required field: {field.name}")
else:
# optional fields can be missing
continue
patches += build_patches(request.user_fields[field.name], field)
raw = json.loads(template.json())
updated = apply_patch(raw, patches)
rendered = JobTemplate.parse_obj(updated)
used_containers = []
for task in rendered.tasks:
for task_container in task.containers:
if task_container.name:
                # containers that already have a name do not need filling in
continue
for entry in request.containers:
if entry.type != task_container.type:
continue
task_container.name = entry.name
used_containers.append(entry)
if not task_container.name:
return _fail(f"missing container definition {task_container.type}")
for entry in request.containers:
if entry not in used_containers:
return _fail(f"unused container in request: {entry}")
return rendered

View File

@ -1,118 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import List, Optional, Tuple
from onefuzztypes.enums import ErrorCode
from onefuzztypes.job_templates import JobTemplateConfig
from onefuzztypes.job_templates import JobTemplateIndex as BASE_INDEX
from onefuzztypes.job_templates import JobTemplateRequest
from onefuzztypes.models import Error, Result, UserInfo
from ..jobs import Job
from ..notifications.main import Notification
from ..orm import ORMMixin
from ..tasks.config import TaskConfigError, check_config
from ..tasks.main import Task
from .defaults import TEMPLATES
from .render import build_input_config, render
class JobTemplateIndex(BASE_INDEX, ORMMixin):
@classmethod
def key_fields(cls) -> Tuple[str, Optional[str]]:
return ("name", None)
@classmethod
def get_base_entry(cls, name: str) -> Optional[BASE_INDEX]:
result = cls.get(name)
if result is not None:
return BASE_INDEX(name=name, template=result.template)
template = TEMPLATES.get(name)
if template is None:
return None
return BASE_INDEX(name=name, template=template)
@classmethod
def get_index(cls) -> List[BASE_INDEX]:
entries = [BASE_INDEX(name=x.name, template=x.template) for x in cls.search()]
# if the local install has replaced the built-in templates, skip over them
for name, template in TEMPLATES.items():
if any(x.name == name for x in entries):
continue
entries.append(BASE_INDEX(name=name, template=template))
return entries
@classmethod
def get_configs(cls) -> List[JobTemplateConfig]:
configs = [build_input_config(x.name, x.template) for x in cls.get_index()]
return configs
@classmethod
def execute(cls, request: JobTemplateRequest, user_info: UserInfo) -> Result[Job]:
index = cls.get(request.name)
if index is None:
if request.name not in TEMPLATES:
return Error(
code=ErrorCode.INVALID_REQUEST,
errors=["no such template: %s" % request.name],
)
base_template = TEMPLATES[request.name]
else:
base_template = index.template
template = render(request, base_template)
if isinstance(template, Error):
return template
try:
for task_config in template.tasks:
check_config(task_config)
if task_config.pool is None:
return Error(
code=ErrorCode.INVALID_REQUEST, errors=["pool not defined"]
)
except TaskConfigError as err:
return Error(code=ErrorCode.INVALID_REQUEST, errors=[str(err)])
for notification_config in template.notifications:
for task_container in request.containers:
if task_container.type == notification_config.container_type:
notification = Notification.create(
task_container.name,
notification_config.notification.config,
True,
)
if isinstance(notification, Error):
return notification
job = Job(config=template.job)
job.save()
tasks: List[Task] = []
for task_config in template.tasks:
task_config.job_id = job.job_id
if task_config.prereq_tasks:
# pydantic verifies prereq_tasks in u128 form are index refs to
# previously generated tasks
task_config.prereq_tasks = [
tasks[x.int].task_id for x in task_config.prereq_tasks
]
task = Task.create(
config=task_config, job_id=job.job_id, user_info=user_info
)
if isinstance(task, Error):
return task
tasks.append(task)
return job

View File

@ -1,144 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
from datetime import datetime, timedelta
from typing import List, Optional, Tuple
from onefuzztypes.enums import ErrorCode, JobState, TaskState
from onefuzztypes.events import EventJobCreated, EventJobStopped, JobTaskStopped
from onefuzztypes.models import Error
from onefuzztypes.models import Job as BASE_JOB
from .events import send_event
from .orm import MappingIntStrAny, ORMMixin, QueryFilter
from .tasks.main import Task
JOB_LOG_PREFIX = "jobs: "
JOB_NEVER_STARTED_DURATION: timedelta = timedelta(days=30)
class Job(BASE_JOB, ORMMixin):
@classmethod
def key_fields(cls) -> Tuple[str, Optional[str]]:
return ("job_id", None)
@classmethod
def search_states(cls, *, states: Optional[List[JobState]] = None) -> List["Job"]:
query: QueryFilter = {}
if states:
query["state"] = states
return cls.search(query=query)
@classmethod
def search_expired(cls) -> List["Job"]:
time_filter = "end_time lt datetime'%s'" % datetime.utcnow().isoformat()
return cls.search(
query={"state": JobState.available()}, raw_unchecked_filter=time_filter
)
@classmethod
def stop_never_started_jobs(cls) -> None:
# Note, the "not(end_time...)" with end_time set long before the use of
# OneFuzz enables identifying those without end_time being set.
last_timestamp = (datetime.utcnow() - JOB_NEVER_STARTED_DURATION).isoformat()
time_filter = (
f"Timestamp lt datetime'{last_timestamp}' and "
"not(end_time ge datetime'2000-01-11T00:00:00.0Z')"
)
for job in cls.search(
query={
"state": [JobState.enabled],
},
raw_unchecked_filter=time_filter,
):
for task in Task.search(query={"job_id": [job.job_id]}):
task.mark_failed(
Error(
code=ErrorCode.TASK_FAILED,
errors=["job never not start"],
)
)
logging.info(
JOB_LOG_PREFIX + "stopping job that never started: %s", job.job_id
)
job.stopping()
def save_exclude(self) -> Optional[MappingIntStrAny]:
return {"task_info": ...}
def telemetry_include(self) -> Optional[MappingIntStrAny]:
return {
"machine_id": ...,
"state": ...,
"scaleset_id": ...,
}
def init(self) -> None:
logging.info(JOB_LOG_PREFIX + "init: %s", self.job_id)
self.state = JobState.enabled
self.save()
def stop_if_all_done(self) -> None:
not_stopped = [
task
for task in Task.search(query={"job_id": [self.job_id]})
if task.state != TaskState.stopped
]
if not_stopped:
return
logging.info(
JOB_LOG_PREFIX + "stopping job as all tasks are stopped: %s", self.job_id
)
self.stopping()
def stopping(self) -> None:
self.state = JobState.stopping
logging.info(JOB_LOG_PREFIX + "stopping: %s", self.job_id)
tasks = Task.search(query={"job_id": [self.job_id]})
not_stopped = [task for task in tasks if task.state != TaskState.stopped]
if not_stopped:
for task in not_stopped:
task.mark_stopping()
else:
self.state = JobState.stopped
task_info = [
JobTaskStopped(
task_id=x.task_id, error=x.error, task_type=x.config.task.type
)
for x in tasks
]
send_event(
EventJobStopped(
job_id=self.job_id,
config=self.config,
user_info=self.user_info,
task_info=task_info,
)
)
self.save()
def on_start(self) -> None:
# try to keep this effectively idempotent
if self.end_time is None:
self.end_time = datetime.utcnow() + timedelta(hours=self.config.duration)
self.save()
def save(self, new: bool = False, require_etag: bool = False) -> None:
created = self.etag is None
super().save(new=new, require_etag=require_etag)
if created:
send_event(
EventJobCreated(
job_id=self.job_id, config=self.config, user_info=self.user_info
)
)
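A small sketch of how the raw Azure Table time filter used above can be assembled and inspected in isolation (values are illustrative):
from datetime import datetime, timedelta
JOB_NEVER_STARTED_DURATION = timedelta(days=30)
def never_started_filter(now: datetime) -> str:
    # jobs older than 30 days whose end_time was never set
    last_timestamp = (now - JOB_NEVER_STARTED_DURATION).isoformat()
    return (
        f"Timestamp lt datetime'{last_timestamp}' and "
        "not(end_time ge datetime'2000-01-11T00:00:00.0Z')"
    )
print(never_started_filter(datetime.utcnow()))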

View File

@ -1,294 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
from typing import Iterator, List, Optional, Tuple, Union
from uuid import UUID
from azure.devops.connection import Connection
from azure.devops.credentials import BasicAuthentication
from azure.devops.exceptions import (
AzureDevOpsAuthenticationError,
AzureDevOpsClientError,
AzureDevOpsClientRequestError,
AzureDevOpsServiceError,
)
from azure.devops.v6_0.work_item_tracking.models import (
CommentCreate,
JsonPatchOperation,
Wiql,
WorkItem,
)
from azure.devops.v6_0.work_item_tracking.work_item_tracking_client import (
WorkItemTrackingClient,
)
from memoization import cached
from onefuzztypes.models import ADOTemplate, RegressionReport, Report
from onefuzztypes.primitives import Container
from ..secrets import get_secret_string_value
from .common import Render, log_failed_notification
class AdoNotificationException(Exception):
pass
@cached(ttl=60)
def get_ado_client(base_url: str, token: str) -> WorkItemTrackingClient:
connection = Connection(base_url=base_url, creds=BasicAuthentication("PAT", token))
client = connection.clients_v6_0.get_work_item_tracking_client()
return client
@cached(ttl=60)
def get_valid_fields(
client: WorkItemTrackingClient, project: Optional[str] = None
) -> List[str]:
valid_fields = [
x.reference_name.lower()
for x in client.get_fields(project=project, expand="ExtensionFields")
]
return valid_fields
class ADO:
def __init__(
self,
container: Container,
filename: str,
config: ADOTemplate,
report: Report,
*,
renderer: Optional[Render] = None,
):
self.config = config
if renderer:
self.renderer = renderer
else:
self.renderer = Render(container, filename, report)
self.project = self.render(self.config.project)
def connect(self) -> None:
auth_token = get_secret_string_value(self.config.auth_token)
self.client = get_ado_client(self.config.base_url, auth_token)
def render(self, template: str) -> str:
return self.renderer.render(template)
def existing_work_items(self) -> Iterator[WorkItem]:
filters = {}
for key in self.config.unique_fields:
if key == "System.TeamProject":
value = self.render(self.config.project)
else:
value = self.render(self.config.ado_fields[key])
filters[key.lower()] = value
valid_fields = get_valid_fields(
self.client, project=filters.get("system.teamproject")
)
post_query_filter = {}
# WIQL (Work Item Query Language) is an SQL-like query language that
# doesn't support query params, safe quoting, or any other SQL-injection
# protection mechanisms.
#
# As such, build the WIQL with only those fields we can pre-determine are
# "safe" and otherwise use post-query filtering.
parts = []
for k, v in filters.items():
# Only add pre-system approved fields to the query
if k not in valid_fields:
post_query_filter[k] = v
continue
# WIQL supports wrapping values in ' or " and escaping ' by doubling it
#
# For this System.Title: hi'there
# use this query fragment: [System.Title] = 'hi''there'
#
# For this System.Title: hi"there
# use this query fragment: [System.Title] = 'hi"there'
#
# For this System.Title: hi'"there
# use this query fragment: [System.Title] = 'hi''"there'
SINGLE = "'"
parts.append("[%s] = '%s'" % (k, v.replace(SINGLE, SINGLE + SINGLE)))
query = "select [System.Id] from WorkItems"
if parts:
query += " where " + " AND ".join(parts)
wiql = Wiql(query=query)
for entry in self.client.query_by_wiql(wiql).work_items:
item = self.client.get_work_item(entry.id, expand="Fields")
lowered_fields = {x.lower(): str(y) for (x, y) in item.fields.items()}
if post_query_filter and not all(
[
k.lower() in lowered_fields and lowered_fields[k.lower()] == v
for (k, v) in post_query_filter.items()
]
):
continue
yield item
def update_existing(self, item: WorkItem, notification_info: str) -> None:
if self.config.on_duplicate.comment:
comment = self.render(self.config.on_duplicate.comment)
self.client.add_comment(
CommentCreate(comment),
self.project,
item.id,
)
document = []
for field in self.config.on_duplicate.increment:
value = int(item.fields[field]) if field in item.fields else 0
value += 1
document.append(
JsonPatchOperation(
op="Replace", path="/fields/%s" % field, value=str(value)
)
)
for field in self.config.on_duplicate.ado_fields:
field_value = self.render(self.config.on_duplicate.ado_fields[field])
document.append(
JsonPatchOperation(
op="Replace", path="/fields/%s" % field, value=field_value
)
)
if item.fields["System.State"] in self.config.on_duplicate.set_state:
document.append(
JsonPatchOperation(
op="Replace",
path="/fields/System.State",
value=self.config.on_duplicate.set_state[
item.fields["System.State"]
],
)
)
if document:
self.client.update_work_item(document, item.id, project=self.project)
logging.info(
f"notify ado: updated work item {item.id} - {notification_info}"
)
else:
logging.info(
f"notify ado: no update for work item {item.id} - {notification_info}"
)
def render_new(self) -> Tuple[str, List[JsonPatchOperation]]:
task_type = self.render(self.config.type)
document = []
if "System.Tags" not in self.config.ado_fields:
document.append(
JsonPatchOperation(
op="Add", path="/fields/System.Tags", value="Onefuzz"
)
)
for field in self.config.ado_fields:
value = self.render(self.config.ado_fields[field])
if field == "System.Tags":
value += ";Onefuzz"
document.append(
JsonPatchOperation(op="Add", path="/fields/%s" % field, value=value)
)
return (task_type, document)
def create_new(self) -> WorkItem:
task_type, document = self.render_new()
entry = self.client.create_work_item(
document=document, project=self.project, type=task_type
)
if self.config.comment:
comment = self.render(self.config.comment)
self.client.add_comment(
CommentCreate(comment),
self.project,
entry.id,
)
return entry
def process(self, notification_info: str) -> None:
seen = False
for work_item in self.existing_work_items():
self.update_existing(work_item, notification_info)
seen = True
if not seen:
entry = self.create_new()
logging.info(
"notify ado: created new work item" f" {entry.id} - {notification_info}"
)
def is_transient(err: Exception) -> bool:
error_codes = [
# "TF401349: An unexpected error has occurred, please verify your request and try again." # noqa: E501
"TF401349",
# TF26071: This work item has been changed by someone else since you opened it. You will need to refresh it and discard your changes. # noqa: E501
"TF26071",
]
error_str = str(err)
for code in error_codes:
if code in error_str:
return True
return False
def notify_ado(
config: ADOTemplate,
container: Container,
filename: str,
report: Union[Report, RegressionReport],
fail_task_on_transient_error: bool,
notification_id: UUID,
) -> None:
if isinstance(report, RegressionReport):
logging.info(
"ado integration does not support regression reports. "
"container:%s filename:%s",
container,
filename,
)
return
notification_info = (
f"job_id:{report.job_id} task_id:{report.task_id}"
f" container:{container} filename:{filename}"
)
logging.info("notify ado: %s", notification_info)
try:
ado = ADO(container, filename, config, report)
ado.connect()
ado.process(notification_info)
except (
AzureDevOpsAuthenticationError,
AzureDevOpsClientError,
AzureDevOpsServiceError,
AzureDevOpsClientRequestError,
ValueError,
) as err:
if not fail_task_on_transient_error and is_transient(err):
raise AdoNotificationException(
f"transient ADO notification failure {notification_info}"
) from err
else:
log_failed_notification(report, err, notification_id)
raise AdoNotificationException(
"Sending file changed event for notification %s to poison queue"
% notification_id
) from err
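A minimal sketch of the WIQL quoting rule described in the comments above (doubling single quotes and bracketing field names); the filter values here are made up:
def wiql_quote(value: str) -> str:
    # WIQL escapes a single quote by doubling it, e.g. hi'there -> 'hi''there'
    SINGLE = "'"
    return "'%s'" % value.replace(SINGLE, SINGLE + SINGLE)
def build_wiql(filters: dict) -> str:
    query = "select [System.Id] from WorkItems"
    parts = ["[%s] = %s" % (field, wiql_quote(value)) for field, value in filters.items()]
    if parts:
        query += " where " + " AND ".join(parts)
    return query
print(build_wiql({"System.Title": "hi'there", "System.TeamProject": "demo"}))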

View File

@ -1,96 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
from typing import Optional
from uuid import UUID
from jinja2.sandbox import SandboxedEnvironment
from onefuzztypes.models import Report
from onefuzztypes.primitives import Container
from ..azure.containers import auth_download_url
from ..azure.creds import get_instance_url
from ..jobs import Job
from ..tasks.config import get_setup_container
from ..tasks.main import Task
def log_failed_notification(
report: Report, error: Exception, notification_id: UUID
) -> None:
logging.error(
"notification failed: notification_id:%s job_id:%s task_id:%s err:%s",
notification_id,
report.job_id,
report.task_id,
error,
)
class Render:
def __init__(
self,
container: Container,
filename: str,
report: Report,
*,
task: Optional[Task] = None,
job: Optional[Job] = None,
target_url: Optional[str] = None,
input_url: Optional[str] = None,
report_url: Optional[str] = None,
):
self.report = report
self.container = container
self.filename = filename
if not task:
task = Task.get(report.job_id, report.task_id)
if not task:
raise ValueError(f"invalid task {report.task_id}")
if not job:
job = Job.get(report.job_id)
if not job:
raise ValueError(f"invalid job {report.job_id}")
self.task_config = task.config
self.job_config = job.config
self.env = SandboxedEnvironment()
self.target_url = target_url
if not self.target_url:
setup_container = get_setup_container(task.config)
if setup_container:
self.target_url = auth_download_url(
setup_container, self.report.executable.replace("setup/", "", 1)
)
if report_url:
self.report_url = report_url
else:
self.report_url = auth_download_url(container, filename)
self.input_url = input_url
if not self.input_url:
if self.report.input_blob:
self.input_url = auth_download_url(
self.report.input_blob.container, self.report.input_blob.name
)
def render(self, template: str) -> str:
return self.env.from_string(template).render(
{
"report": self.report,
"task": self.task_config,
"job": self.job_config,
"report_url": self.report_url,
"input_url": self.input_url,
"target_url": self.target_url,
"report_container": self.container,
"report_filename": self.filename,
"repro_cmd": "onefuzz --endpoint %s repro create_and_connect %s %s"
% (get_instance_url(), self.container, self.filename),
}
)
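Render relies on a sandboxed Jinja2 environment so notification templates cannot execute arbitrary Python. A small usage sketch with made-up values:
from jinja2.sandbox import SandboxedEnvironment
env = SandboxedEnvironment()
template = "new crash in {{ report.executable }} ({{ report_container }}/{{ report_filename }})"
print(
    env.from_string(template).render(
        {
            "report": {"executable": "fuzz.exe"},
            "report_container": "reports",
            "report_filename": "crash.json",
        }
    )
)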

View File

@ -1,138 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
from typing import List, Optional, Union
from uuid import UUID
from github3 import login
from github3.exceptions import GitHubException
from github3.issues import Issue
from onefuzztypes.enums import GithubIssueSearchMatch
from onefuzztypes.models import (
GithubAuth,
GithubIssueTemplate,
RegressionReport,
Report,
)
from onefuzztypes.primitives import Container
from ..secrets import get_secret_obj
from .common import Render, log_failed_notification
class GithubIssue:
def __init__(
self,
config: GithubIssueTemplate,
container: Container,
filename: str,
report: Report,
):
self.config = config
self.report = report
if isinstance(config.auth.secret, GithubAuth):
auth = config.auth.secret
else:
auth = get_secret_obj(config.auth.secret.url, GithubAuth)
self.gh = login(username=auth.user, password=auth.personal_access_token)
self.renderer = Render(container, filename, report)
def render(self, field: str) -> str:
return self.renderer.render(field)
def existing(self) -> List[Issue]:
query = [
self.render(self.config.unique_search.string),
"repo:%s/%s"
% (
self.render(self.config.organization),
self.render(self.config.repository),
),
]
if self.config.unique_search.author:
query.append("author:%s" % self.render(self.config.unique_search.author))
if self.config.unique_search.state:
query.append("state:%s" % self.config.unique_search.state.name)
issues = []
title = self.render(self.config.title)
body = self.render(self.config.body)
for issue in self.gh.search_issues(" ".join(query)):
skip = False
for field in self.config.unique_search.field_match:
if field == GithubIssueSearchMatch.title and issue.title != title:
skip = True
break
if field == GithubIssueSearchMatch.body and issue.body != body:
skip = True
break
if not skip:
issues.append(issue)
return issues
def update(self, issue: Issue) -> None:
logging.info("updating issue: %s", issue)
if self.config.on_duplicate.comment:
issue.issue.create_comment(self.render(self.config.on_duplicate.comment))
if self.config.on_duplicate.labels:
labels = [self.render(x) for x in self.config.on_duplicate.labels]
issue.issue.edit(labels=labels)
if self.config.on_duplicate.reopen and issue.state != "open":
issue.issue.edit(state="open")
def create(self) -> None:
logging.info("creating issue")
assignees = [self.render(x) for x in self.config.assignees]
labels = list(set(["OneFuzz"] + [self.render(x) for x in self.config.labels]))
self.gh.create_issue(
self.render(self.config.organization),
self.render(self.config.repository),
self.render(self.config.title),
body=self.render(self.config.body),
labels=labels,
assignees=assignees,
)
def process(self) -> None:
issues = self.existing()
if issues:
self.update(issues[0])
else:
self.create()
def github_issue(
config: GithubIssueTemplate,
container: Container,
filename: str,
report: Optional[Union[Report, RegressionReport]],
notification_id: UUID,
) -> None:
if report is None:
return
if isinstance(report, RegressionReport):
logging.info(
"github issue integration does not support regression reports. "
"container:%s filename:%s",
container,
filename,
)
return
try:
handler = GithubIssue(config, container, filename, report)
handler.process()
except (GitHubException, ValueError) as err:
log_failed_notification(report, err, notification_id)
raise GitHubException(
"Sending file change event for notification %s to poison queue"
% notification_id
) from err
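A sketch of the issue-deduplication query assembled above, reduced to a pure string-building helper (organization, repository, and search string are placeholders):
from typing import Optional
def github_dedupe_query(
    search_string: str,
    organization: str,
    repository: str,
    author: Optional[str] = None,
    state: Optional[str] = None,
) -> str:
    # free-text search string plus repo/author/state qualifiers, as built above
    query = [search_string, "repo:%s/%s" % (organization, repository)]
    if author:
        query.append("author:%s" % author)
    if state:
        query.append("state:%s" % state)
    return " ".join(query)
print(github_dedupe_query("fuzzing crash deadbeef", "contoso", "target", state="open"))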

View File

@ -1,200 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
from typing import List, Optional, Sequence, Tuple
from uuid import UUID
from memoization import cached
from onefuzztypes import models
from onefuzztypes.enums import ErrorCode, TaskState
from onefuzztypes.events import (
EventCrashReported,
EventFileAdded,
EventRegressionReported,
)
from onefuzztypes.models import (
ADOTemplate,
Error,
GithubIssueTemplate,
NotificationTemplate,
RegressionReport,
Report,
Result,
TeamsTemplate,
)
from onefuzztypes.primitives import Container
from ..azure.containers import container_exists, get_file_sas_url
from ..azure.queue import send_message
from ..azure.storage import StorageType
from ..events import send_event
from ..orm import ORMMixin
from ..reports import get_report_or_regression
from ..tasks.config import get_input_container_queues
from ..tasks.main import Task
from .ado import notify_ado
from .github_issues import github_issue
from .teams import notify_teams
class Notification(models.Notification, ORMMixin):
@classmethod
def get_by_id(cls, notification_id: UUID) -> Result["Notification"]:
notifications = cls.search(query={"notification_id": [notification_id]})
if not notifications:
return Error(
code=ErrorCode.INVALID_REQUEST, errors=["unable to find Notification"]
)
if len(notifications) != 1:
return Error(
code=ErrorCode.INVALID_REQUEST,
errors=["error identifying Notification"],
)
notification = notifications[0]
return notification
@classmethod
def key_fields(cls) -> Tuple[str, str]:
return ("notification_id", "container")
@classmethod
def create(
cls, container: Container, config: NotificationTemplate, replace_existing: bool
) -> Result["Notification"]:
if not container_exists(container, StorageType.corpus):
return Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid container"])
if replace_existing:
existing = cls.search(query={"container": [container]})
for entry in existing:
logging.info(
"replacing existing notification: %s - %s",
entry.notification_id,
container,
)
entry.delete()
entry = cls(container=container, config=config)
entry.save()
logging.info(
"created notification. notification_id:%s container:%s",
entry.notification_id,
entry.container,
)
return entry
@cached(ttl=10)
def get_notifications(container: Container) -> List[Notification]:
return Notification.search(query={"container": [container]})
def get_regression_report_task(report: RegressionReport) -> Optional[Task]:
# crash_test_result is required, but report & no_repro are not
if report.crash_test_result.crash_report:
return Task.get(
report.crash_test_result.crash_report.job_id,
report.crash_test_result.crash_report.task_id,
)
if report.crash_test_result.no_repro:
return Task.get(
report.crash_test_result.no_repro.job_id,
report.crash_test_result.no_repro.task_id,
)
logging.error(
"unable to find crash_report or no_repro entry for report: %s",
report.json(include_none=False),
)
return None
@cached(ttl=10)
def get_queue_tasks() -> Sequence[Tuple[Task, Sequence[str]]]:
results = []
for task in Task.search_states(states=TaskState.available()):
containers = get_input_container_queues(task.config)
if containers:
results.append((task, containers))
return results
def new_files(
container: Container, filename: str, fail_task_on_transient_error: bool
) -> None:
notifications = get_notifications(container)
report = get_report_or_regression(
container, filename, expect_reports=bool(notifications)
)
if notifications:
done = []
for notification in notifications:
# ignore duplicate configurations
if notification.config in done:
continue
done.append(notification.config)
if isinstance(notification.config, TeamsTemplate):
notify_teams(
notification.config,
container,
filename,
report,
notification.notification_id,
)
if not report:
continue
if isinstance(notification.config, ADOTemplate):
notify_ado(
notification.config,
container,
filename,
report,
fail_task_on_transient_error,
notification.notification_id,
)
if isinstance(notification.config, GithubIssueTemplate):
github_issue(
notification.config,
container,
filename,
report,
notification.notification_id,
)
for (task, containers) in get_queue_tasks():
if container in containers:
logging.info("queuing input %s %s %s", container, filename, task.task_id)
url = get_file_sas_url(
container, filename, StorageType.corpus, read=True, delete=True
)
send_message(task.task_id, bytes(url, "utf-8"), StorageType.corpus)
if isinstance(report, Report):
crash_report_event = EventCrashReported(
report=report, container=container, filename=filename
)
report_task = Task.get(report.job_id, report.task_id)
if report_task:
crash_report_event.task_config = report_task.config
send_event(crash_report_event)
elif isinstance(report, RegressionReport):
regression_event = EventRegressionReported(
regression_report=report, container=container, filename=filename
)
report_task = get_regression_report_task(report)
if report_task:
regression_event.task_config = report_task.config
send_event(regression_event)
else:
send_event(EventFileAdded(container=container, filename=filename))
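Both caches above come from the memoization package; entries expire after ttl seconds. A tiny usage sketch with a made-up lookup:
from memoization import cached
@cached(ttl=10)
def expensive_lookup(container: str) -> list:
    print(f"querying table storage for {container}")
    return []
expensive_lookup("demo")   # hits the table
expensive_lookup("demo")   # served from the cache for up to 10 seconds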

View File

@ -1,141 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
from typing import Any, Dict, List, Optional, Union
from uuid import UUID
import requests
from onefuzztypes.models import RegressionReport, Report, TeamsTemplate
from onefuzztypes.primitives import Container
from ..azure.containers import auth_download_url
from ..secrets import get_secret_string_value
from ..tasks.config import get_setup_container
from ..tasks.main import Task
def markdown_escape(data: str) -> str:
values = r"\\*_{}[]()#+-.!" # noqa: P103
for value in values:
data = data.replace(value, "\\" + value)
data = data.replace("`", "``")
return data
def code_block(data: str) -> str:
data = data.replace("`", "``")
return "\n```\n%s\n```\n" % data
def send_teams_webhook(
config: TeamsTemplate,
title: str,
facts: List[Dict[str, str]],
text: Optional[str],
notification_id: UUID,
) -> None:
title = markdown_escape(title)
message: Dict[str, Any] = {
"@type": "MessageCard",
"@context": "https://schema.org/extensions",
"summary": title,
"sections": [{"activityTitle": title, "facts": facts}],
}
if text:
message["sections"].append({"text": text})
config_url = get_secret_string_value(config.url)
response = requests.post(config_url, json=message)
if not response.ok:
logging.error(
"webhook failed notification_id:%s %s %s",
notification_id,
response.status_code,
response.content,
)
def notify_teams(
config: TeamsTemplate,
container: Container,
filename: str,
report: Optional[Union[Report, RegressionReport]],
notification_id: UUID,
) -> None:
text = None
facts: List[Dict[str, str]] = []
if isinstance(report, Report):
task = Task.get(report.job_id, report.task_id)
if not task:
logging.error(
"report with invalid task %s:%s", report.job_id, report.task_id
)
return
title = "new crash in %s: %s @ %s" % (
report.executable,
report.crash_type,
report.crash_site,
)
links = [
"[report](%s)" % auth_download_url(container, filename),
]
setup_container = get_setup_container(task.config)
if setup_container:
links.append(
"[executable](%s)"
% auth_download_url(
setup_container,
report.executable.replace("setup/", "", 1),
),
)
if report.input_blob:
links.append(
"[input](%s)"
% auth_download_url(
report.input_blob.container, report.input_blob.name
),
)
facts += [
{"name": "Files", "value": " | ".join(links)},
{
"name": "Task",
"value": markdown_escape(
"job_id: %s task_id: %s" % (report.job_id, report.task_id)
),
},
{
"name": "Repro",
"value": code_block(
"onefuzz repro create_and_connect %s %s" % (container, filename)
),
},
]
text = "## Call Stack\n" + "\n".join(code_block(x) for x in report.call_stack)
else:
title = "new file found"
facts += [
{
"name": "file",
"value": "[%s/%s](%s)"
% (
markdown_escape(container),
markdown_escape(filename),
auth_download_url(container, filename),
),
}
]
send_teams_webhook(config, title, facts, text, notification_id)
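A sketch of the Teams MessageCard payload shape built above, without the webhook call (title and facts are placeholders):
import json
from typing import Any, Dict, List, Optional
def build_message_card(
    title: str, facts: List[Dict[str, str]], text: Optional[str] = None
) -> Dict[str, Any]:
    message: Dict[str, Any] = {
        "@type": "MessageCard",
        "@context": "https://schema.org/extensions",
        "summary": title,
        "sections": [{"activityTitle": title, "facts": facts}],
    }
    if text:
        message["sections"].append({"text": text})
    return message
card = build_message_card(
    "new crash in fuzz.exe: heap-overflow @ parse_header",
    [{"name": "Task", "value": "job_id: demo task_id: demo"}],
)
print(json.dumps(card, indent=2))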

View File

@ -1,505 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import inspect
import json
import logging
from datetime import datetime
from enum import Enum
from typing import (
Any,
Callable,
Dict,
List,
Mapping,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
from uuid import UUID
from azure.common import AzureConflictHttpError, AzureMissingResourceHttpError
from onefuzztypes.enums import (
ErrorCode,
JobState,
NodeState,
PoolState,
ScalesetState,
TaskState,
TelemetryEvent,
UpdateType,
VmState,
)
from onefuzztypes.models import Error, SecretData
from onefuzztypes.primitives import Container, PoolName, Region
from pydantic import BaseModel
from typing_extensions import Protocol
from .azure.table import get_client
from .secrets import delete_remote_secret_data, save_to_keyvault
from .telemetry import track_event_filtered
from .updates import queue_update
A = TypeVar("A", bound="ORMMixin")
QUERY_VALUE_TYPES = Union[
List[bool],
List[int],
List[str],
List[UUID],
List[Region],
List[Container],
List[PoolName],
List[VmState],
List[ScalesetState],
List[JobState],
List[TaskState],
List[PoolState],
List[NodeState],
]
QueryFilter = Dict[str, QUERY_VALUE_TYPES]
SAFE_STRINGS = (UUID, Container, Region, PoolName)
KEY = Union[int, str, UUID, Enum]
HOURS = 60 * 60
class HasState(Protocol):
# TODO: this should be bound tighter than Any
# In the end, we want this to be an Enum. Specifically, one of
# the JobState,TaskState,etc enums.
state: Any
def get_keys(self) -> Tuple[KEY, KEY]:
...
def process_state_update(obj: HasState) -> None:
"""
process a single state update, if the obj
implements a function for that state
"""
func = getattr(obj, obj.state.name, None)
if func is None:
return
keys = obj.get_keys()
logging.info(
"processing state update: %s - %s - %s", type(obj), keys, obj.state.name
)
func()
def process_state_updates(obj: HasState, max_updates: int = 5) -> None:
"""process through the state machine for an object"""
for _ in range(max_updates):
state = obj.state
process_state_update(obj)
new_state = obj.state
if new_state == state:
break
def resolve(key: KEY) -> str:
if isinstance(key, str):
return key
elif isinstance(key, UUID):
return str(key)
elif isinstance(key, Enum):
return key.name
elif isinstance(key, int):
return str(key)
raise NotImplementedError("unsupported type %s - %s" % (type(key), repr(key)))
def build_filters(
cls: Type[A], query_args: Optional[QueryFilter]
) -> Tuple[Optional[str], QueryFilter]:
if not query_args:
return (None, {})
partition_key_field, row_key_field = cls.key_fields()
search_filter_parts = []
post_filters: QueryFilter = {}
for field, values in query_args.items():
if field not in cls.__fields__:
raise ValueError("unexpected field %s: %s" % (repr(field), cls))
if not values:
continue
if field == partition_key_field:
field_name = "PartitionKey"
elif field == row_key_field:
field_name = "RowKey"
else:
field_name = field
parts: Optional[List[str]] = None
if isinstance(values[0], bool):
parts = []
for x in values:
if not isinstance(x, bool):
raise TypeError("unexpected type")
parts.append("%s eq %s" % (field_name, str(x).lower()))
elif isinstance(values[0], int):
parts = []
for x in values:
if not isinstance(x, int):
raise TypeError("unexpected type")
parts.append("%s eq %d" % (field_name, x))
elif isinstance(values[0], Enum):
parts = []
for x in values:
if not isinstance(x, Enum):
raise TypeError("unexpected type")
parts.append("%s eq '%s'" % (field_name, x.name))
elif all(isinstance(x, SAFE_STRINGS) for x in values):
parts = ["%s eq '%s'" % (field_name, x) for x in values]
else:
post_filters[field_name] = values
if parts:
if len(parts) == 1:
search_filter_parts.append(parts[0])
else:
search_filter_parts.append("(" + " or ".join(parts) + ")")
if search_filter_parts:
return (" and ".join(search_filter_parts), post_filters)
return (None, post_filters)
def post_filter(value: Any, filters: Optional[QueryFilter]) -> bool:
if not filters:
return True
for field in filters:
if field not in value:
return False
if value[field] not in filters[field]:
return False
return True
MappingIntStrAny = Mapping[Union[int, str], Any]
# A = TypeVar("A", bound="Model")
class ModelMixin(BaseModel):
def export_exclude(self) -> Optional[MappingIntStrAny]:
return None
def raw(
self,
*,
by_alias: bool = False,
exclude_none: bool = False,
exclude: Optional[MappingIntStrAny] = None,
include: Optional[MappingIntStrAny] = None,
) -> Dict[str, Any]:
# cycling through json means all wrapped types get resolved, such as UUID
result: Dict[str, Any] = json.loads(
self.json(
by_alias=by_alias,
exclude_none=exclude_none,
exclude=exclude,
include=include,
)
)
return result
B = TypeVar("B", bound=BaseModel)
def hide_secrets(data: B, hider: Callable[[SecretData], SecretData]) -> B:
for field in data.__fields__:
field_data = getattr(data, field)
if isinstance(field_data, SecretData):
field_data = hider(field_data)
elif isinstance(field_data, BaseModel):
field_data = hide_secrets(field_data, hider)
elif isinstance(field_data, list):
field_data = [
hide_secrets(x, hider) if isinstance(x, BaseModel) else x
for x in field_data
]
elif isinstance(field_data, dict):
for key in field_data:
if not isinstance(field_data[key], BaseModel):
continue
field_data[key] = hide_secrets(field_data[key], hider)
setattr(data, field, field_data)
return data
# NOTE: the actual deletion must come from the `deleter` callback function
def delete_secrets(data: B, deleter: Callable[[SecretData], None]) -> None:
for field in data.__fields__:
field_data = getattr(data, field)
if isinstance(field_data, SecretData):
deleter(field_data)
elif isinstance(field_data, BaseModel):
delete_secrets(field_data, deleter)
elif isinstance(field_data, list):
for entry in field_data:
if isinstance(entry, BaseModel):
delete_secrets(entry, deleter)
elif isinstance(entry, SecretData):
deleter(entry)
elif isinstance(field_data, dict):
for value in field_data.values():
if isinstance(value, BaseModel):
delete_secrets(value, deleter)
elif isinstance(value, SecretData):
deleter(value)
# NOTE: if you want to include Timestamp in a model that uses ORMMixin,
# it must be maintained as part of the model.
class ORMMixin(ModelMixin):
etag: Optional[str]
@classmethod
def table_name(cls: Type[A]) -> str:
return cls.__name__
@classmethod
def get(
cls: Type[A], PartitionKey: KEY, RowKey: Optional[KEY] = None
) -> Optional[A]:
client = get_client(table=cls.table_name())
partition_key = resolve(PartitionKey)
row_key = resolve(RowKey) if RowKey else partition_key
try:
raw = client.get_entity(cls.table_name(), partition_key, row_key)
except AzureMissingResourceHttpError:
return None
return cls.load(raw)
@classmethod
def key_fields(cls) -> Tuple[str, Optional[str]]:
raise NotImplementedError("keys not defined")
# FILTERS:
# The following
# * save_exclude: Specify fields to *exclude* from saving to Storage Tables
# * export_exclude: Specify the fields to *exclude* from sending to an external API
# * telemetry_include: Specify the fields to *include* for telemetry
#
# For implementation details see:
# https://pydantic-docs.helpmanual.io/usage/exporting_models/#advanced-include-and-exclude
def save_exclude(self) -> Optional[MappingIntStrAny]:
return None
def export_exclude(self) -> Optional[MappingIntStrAny]:
return {"etag": ...}
def telemetry_include(self) -> Optional[MappingIntStrAny]:
return {}
def telemetry(self) -> Any:
return self.raw(exclude_none=True, include=self.telemetry_include())
def get_keys(self) -> Tuple[KEY, KEY]:
partition_key_field, row_key_field = self.key_fields()
partition_key = getattr(self, partition_key_field)
if row_key_field:
row_key = getattr(self, row_key_field)
else:
row_key = partition_key
return (partition_key, row_key)
def save(self, new: bool = False, require_etag: bool = False) -> Optional[Error]:
self = hide_secrets(self, save_to_keyvault)
# TODO: migrate to an inspect.signature() model
raw = self.raw(by_alias=True, exclude_none=True, exclude=self.save_exclude())
for key in raw:
if not isinstance(raw[key], (str, int)):
raw[key] = json.dumps(raw[key])
for field in self.__fields__:
if field not in raw:
continue
# for datetime fields that passed through filtering, use the real value,
# rather than a serialized form
if self.__fields__[field].type_ == datetime:
raw[field] = getattr(self, field)
partition_key_field, row_key_field = self.key_fields()
# PartitionKey and RowKey must be 'str'
raw["PartitionKey"] = resolve(raw[partition_key_field])
raw["RowKey"] = resolve(raw[row_key_field or partition_key_field])
del raw[partition_key_field]
if row_key_field in raw:
del raw[row_key_field]
client = get_client(table=self.table_name())
# never save the timestamp
if "Timestamp" in raw:
del raw["Timestamp"]
if new:
try:
self.etag = client.insert_entity(self.table_name(), raw)
except AzureConflictHttpError:
return Error(
code=ErrorCode.UNABLE_TO_CREATE, errors=["entry already exists"]
)
elif self.etag and require_etag:
self.etag = client.replace_entity(
self.table_name(), raw, if_match=self.etag
)
else:
self.etag = client.insert_or_replace_entity(self.table_name(), raw)
if self.table_name() in TelemetryEvent.__members__:
telem = self.telemetry()
if telem:
track_event_filtered(TelemetryEvent[self.table_name()], telem)
return None
def delete(self) -> None:
partition_key, row_key = self.get_keys()
delete_secrets(self, delete_remote_secret_data)
client = get_client()
try:
client.delete_entity(
self.table_name(), resolve(partition_key), resolve(row_key)
)
except AzureMissingResourceHttpError:
# It's OK if the component is already deleted
pass
@classmethod
def load(cls: Type[A], data: Dict[str, Union[str, bytes, bytearray]]) -> A:
partition_key_field, row_key_field = cls.key_fields()
if partition_key_field in data:
raise Exception(
"duplicate PartitionKey field %s for %s"
% (partition_key_field, cls.table_name())
)
if row_key_field in data:
raise Exception(
"duplicate RowKey field %s for %s" % (row_key_field, cls.table_name())
)
data[partition_key_field] = data["PartitionKey"]
if row_key_field is not None:
data[row_key_field] = data["RowKey"]
del data["PartitionKey"]
del data["RowKey"]
for key in inspect.signature(cls).parameters:
if key not in data:
continue
annotation = inspect.signature(cls).parameters[key].annotation
if inspect.isclass(annotation):
if (
issubclass(annotation, BaseModel)
or issubclass(annotation, dict)
or issubclass(annotation, list)
):
data[key] = json.loads(data[key])
continue
if getattr(annotation, "__origin__", None) == Union and any(
inspect.isclass(x) and issubclass(x, BaseModel)
for x in annotation.__args__
):
data[key] = json.loads(data[key])
continue
# Required for Python >=3.7. In 3.6, `Dict[_,_]` and `List[_]` annotations
# are classes according to `inspect.isclass`.
if getattr(annotation, "__origin__", None) in [dict, list]:
data[key] = json.loads(data[key])
continue
return cls.parse_obj(data)
@classmethod
def search(
cls: Type[A],
*,
query: Optional[QueryFilter] = None,
raw_unchecked_filter: Optional[str] = None,
num_results: Optional[int] = None,
) -> List[A]:
search_filter, post_filters = build_filters(cls, query)
if raw_unchecked_filter is not None:
if search_filter is None:
search_filter = raw_unchecked_filter
else:
search_filter = "(%s) and (%s)" % (search_filter, raw_unchecked_filter)
client = get_client(table=cls.table_name())
entries = []
for row in client.query_entities(
cls.table_name(), filter=search_filter, num_results=num_results
):
if not post_filter(row, post_filters):
continue
entry = cls.load(row)
entries.append(entry)
return entries
def queue(
self,
*,
method: Optional[Callable] = None,
visibility_timeout: Optional[int] = None,
) -> None:
if not hasattr(self, "state"):
raise NotImplementedError("Queued an ORM mapping without State")
update_type = UpdateType.__members__.get(type(self).__name__)
if update_type is None:
raise NotImplementedError("unsupported update type: %s" % self)
method_name: Optional[str] = None
if method is not None:
if not hasattr(method, "__name__"):
raise Exception("unable to queue method: %s" % method)
method_name = method.__name__
partition_key, row_key = self.get_keys()
queue_update(
update_type,
resolve(partition_key),
resolve(row_key),
method=method_name,
visibility_timeout=visibility_timeout,
)
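A minimal sketch of the state-machine dispatch above: a method named after the current state is looked up with getattr and invoked. This assumes process_state_update/process_state_updates from this module are in scope; the demo class is made up:
from enum import Enum
class DemoState(Enum):
    init = "init"
    running = "running"
class DemoJob:
    def __init__(self) -> None:
        self.state = DemoState.init
    def get_keys(self):
        return ("demo", "demo")
    # method named after the state it handles
    def init(self) -> None:
        self.state = DemoState.running
job = DemoJob()
process_state_updates(job)   # calls job.init(), which advances the state
print(job.state)             # DemoState.running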

View File

@ -1,346 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import datetime
import logging
import os
from typing import List, Optional, Tuple
from uuid import UUID, uuid4
import base58
from azure.mgmt.compute.models import VirtualMachine
from onefuzztypes.enums import ErrorCode, VmState
from onefuzztypes.events import (
EventProxyCreated,
EventProxyDeleted,
EventProxyFailed,
EventProxyStateUpdated,
)
from onefuzztypes.models import (
Authentication,
Error,
Forward,
ProxyConfig,
ProxyHeartbeat,
)
from onefuzztypes.primitives import Container, Region
from pydantic import Field
from .__version__ import __version__
from .azure.auth import build_auth
from .azure.containers import get_file_sas_url, save_blob
from .azure.creds import get_instance_id
from .azure.ip import get_public_ip
from .azure.nsg import NSG
from .azure.queue import get_queue_sas
from .azure.storage import StorageType
from .azure.vm import VM
from .config import InstanceConfig
from .events import send_event
from .extension import proxy_manager_extensions
from .orm import ORMMixin, QueryFilter
from .proxy_forward import ProxyForward
PROXY_LOG_PREFIX = "scaleset-proxy: "
PROXY_LIFESPAN = datetime.timedelta(days=7)
# This isn't intended to ever be shared with the client, hence not being in
# onefuzztypes
class Proxy(ORMMixin):
timestamp: Optional[datetime.datetime] = Field(alias="Timestamp")
created_timestamp: datetime.datetime = Field(
default_factory=datetime.datetime.utcnow
)
proxy_id: UUID = Field(default_factory=uuid4)
region: Region
state: VmState = Field(default=VmState.init)
auth: Authentication = Field(default_factory=build_auth)
ip: Optional[str]
error: Optional[Error]
version: str = Field(default=__version__)
heartbeat: Optional[ProxyHeartbeat]
outdated: bool = Field(default=False)
@classmethod
def key_fields(cls) -> Tuple[str, Optional[str]]:
return ("region", "proxy_id")
def get_vm(self, config: InstanceConfig) -> VM:
config = InstanceConfig.fetch()
sku = config.proxy_vm_sku
tags = None
if config.vm_tags:
tags = config.vm_tags
vm = VM(
name="proxy-%s" % base58.b58encode(self.proxy_id.bytes).decode(),
region=self.region,
sku=sku,
image=config.default_linux_vm_image,
auth=self.auth,
tags=tags,
)
return vm
def init(self) -> None:
config = InstanceConfig.fetch()
vm = self.get_vm(config)
vm_data = vm.get()
if vm_data:
if vm_data.provisioning_state == "Failed":
self.set_provision_failed(vm_data)
return
else:
self.save_proxy_config()
self.set_state(VmState.extensions_launch)
else:
nsg = NSG(
name=self.region,
region=self.region,
)
result = nsg.create()
if isinstance(result, Error):
self.set_failed(result)
return
nsg_config = config.proxy_nsg_config
result = nsg.set_allowed_sources(nsg_config)
if isinstance(result, Error):
self.set_failed(result)
return
vm.nsg = nsg
result = vm.create()
if isinstance(result, Error):
self.set_failed(result)
return
self.save()
def set_provision_failed(self, vm_data: VirtualMachine) -> None:
errors = ["provisioning failed"]
for status in vm_data.instance_view.statuses:
if status.level.name.lower() == "error":
errors.append(
f"code:{status.code} status:{status.display_status} "
f"message:{status.message}"
)
self.set_failed(
Error(
code=ErrorCode.PROXY_FAILED,
errors=errors,
)
)
return
def set_failed(self, error: Error) -> None:
if self.error is not None:
return
logging.error(PROXY_LOG_PREFIX + "vm failed: %s - %s", self.region, error)
send_event(
EventProxyFailed(region=self.region, proxy_id=self.proxy_id, error=error)
)
self.error = error
self.set_state(VmState.stopping)
def extensions_launch(self) -> None:
config = InstanceConfig.fetch()
vm = self.get_vm(config)
vm_data = vm.get()
if not vm_data:
self.set_failed(
Error(
code=ErrorCode.PROXY_FAILED,
errors=["azure not able to find vm"],
)
)
return
if vm_data.provisioning_state == "Failed":
self.set_provision_failed(vm_data)
return
ip = get_public_ip(vm_data.network_profile.network_interfaces[0].id)
if ip is None:
self.save()
return
self.ip = ip
extensions = proxy_manager_extensions(self.region, self.proxy_id)
result = vm.add_extensions(extensions)
if isinstance(result, Error):
self.set_failed(result)
return
elif result:
self.set_state(VmState.running)
self.save()
def stopping(self) -> None:
config = InstanceConfig.fetch()
vm = self.get_vm(config)
if not vm.is_deleted():
logging.info(PROXY_LOG_PREFIX + "stopping proxy: %s", self.region)
vm.delete()
self.save()
else:
self.stopped()
def stopped(self) -> None:
self.set_state(VmState.stopped)
logging.info(PROXY_LOG_PREFIX + "removing proxy: %s", self.region)
send_event(EventProxyDeleted(region=self.region, proxy_id=self.proxy_id))
self.delete()
def is_outdated(self) -> bool:
if self.state not in VmState.available():
return True
if self.version != __version__:
logging.info(
PROXY_LOG_PREFIX + "mismatch version: proxy:%s service:%s state:%s",
self.version,
__version__,
self.state,
)
return True
if self.created_timestamp is not None:
proxy_timestamp = self.created_timestamp
if proxy_timestamp < (
datetime.datetime.now(tz=datetime.timezone.utc) - PROXY_LIFESPAN
):
logging.info(
PROXY_LOG_PREFIX
+ "proxy older than 7 days:proxy-created:%s state:%s",
self.created_timestamp,
self.state,
)
return True
return False
def is_used(self) -> bool:
if len(self.get_forwards()) == 0:
logging.info(PROXY_LOG_PREFIX + "no forwards: %s", self.region)
return False
return True
def is_alive(self) -> bool:
# Unfortunately, both naive and TZ-aware timestamps are required for the
# comparisons below, or exceptions are generated
ten_minutes_ago_no_tz = datetime.datetime.utcnow() - datetime.timedelta(
minutes=10
)
ten_minutes_ago = ten_minutes_ago_no_tz.astimezone(datetime.timezone.utc)
if (
self.heartbeat is not None
and self.heartbeat.timestamp < ten_minutes_ago_no_tz
):
logging.error(
PROXY_LOG_PREFIX + "last heartbeat is more than an 10 minutes old: "
"%s - last heartbeat:%s compared_to:%s",
self.region,
self.heartbeat,
ten_minutes_ago_no_tz,
)
return False
elif not self.heartbeat and self.timestamp and self.timestamp < ten_minutes_ago:
logging.error(
PROXY_LOG_PREFIX + "no heartbeat in the last 10 minutes: "
"%s timestamp: %s compared_to:%s",
self.region,
self.timestamp,
ten_minutes_ago,
)
return False
return True
def get_forwards(self) -> List[Forward]:
forwards: List[Forward] = []
for entry in ProxyForward.search_forward(
region=self.region, proxy_id=self.proxy_id
):
if entry.endtime < datetime.datetime.now(tz=datetime.timezone.utc):
entry.delete()
else:
forwards.append(
Forward(
src_port=entry.port,
dst_ip=entry.dst_ip,
dst_port=entry.dst_port,
)
)
return forwards
def save_proxy_config(self) -> None:
forwards = self.get_forwards()
proxy_config = ProxyConfig(
url=get_file_sas_url(
Container("proxy-configs"),
"%s/%s/config.json" % (self.region, self.proxy_id),
StorageType.config,
read=True,
),
notification=get_queue_sas(
"proxy",
StorageType.config,
add=True,
),
forwards=forwards,
region=self.region,
proxy_id=self.proxy_id,
instance_telemetry_key=os.environ.get("APPINSIGHTS_INSTRUMENTATIONKEY"),
microsoft_telemetry_key=os.environ.get("ONEFUZZ_TELEMETRY"),
instance_id=get_instance_id(),
)
save_blob(
Container("proxy-configs"),
"%s/%s/config.json" % (self.region, self.proxy_id),
proxy_config.json(),
StorageType.config,
)
@classmethod
def search_states(cls, *, states: Optional[List[VmState]] = None) -> List["Proxy"]:
query: QueryFilter = {}
if states:
query["state"] = states
return cls.search(query=query)
@classmethod
def get_or_create(cls, region: Region) -> Optional["Proxy"]:
proxy_list = Proxy.search(query={"region": [region], "outdated": [False]})
for proxy in proxy_list:
if proxy.is_outdated():
proxy.outdated = True
proxy.save()
continue
if proxy.state not in VmState.available():
continue
return proxy
logging.info(PROXY_LOG_PREFIX + "creating proxy: region:%s", region)
proxy = Proxy(region=region)
proxy.save()
send_event(EventProxyCreated(region=region, proxy_id=proxy.proxy_id))
return proxy
def set_state(self, state: VmState) -> None:
if self.state == state:
return
self.state = state
self.save()
send_event(
EventProxyStateUpdated(
region=self.region, proxy_id=self.proxy_id, state=self.state
)
)
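is_alive keeps separate naive and tz-aware cutoffs because the heartbeat timestamp and the table Timestamp carry different timezone information, and mixing them in a comparison raises. A small illustration:
import datetime
# naive cutoff, for comparing against naive heartbeat timestamps
cutoff_naive = datetime.datetime.utcnow() - datetime.timedelta(minutes=10)
# tz-aware cutoff, for comparing against the tz-aware table Timestamp
cutoff_aware = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(minutes=10)
heartbeat = datetime.datetime.utcnow() - datetime.timedelta(minutes=30)  # naive
row_timestamp = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(minutes=30)  # aware
print(heartbeat < cutoff_naive)        # True: heartbeat is stale
print(row_timestamp < cutoff_aware)    # True: row timestamp is stale
try:
    heartbeat < cutoff_aware           # mixing naive and aware datetimes
except TypeError as err:
    print(f"mixed comparison fails: {err}")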

View File

@ -1,143 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import datetime
import logging
from typing import List, Optional, Tuple, Union
from uuid import UUID
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error, Forward
from onefuzztypes.primitives import Region
from pydantic import Field
from .azure.ip import get_scaleset_instance_ip
from .orm import ORMMixin, QueryFilter
PORT_RANGES = range(28000, 32000)
# This isn't intended to ever be shared with the client, hence not being in
# onefuzztypes
class ProxyForward(ORMMixin):
region: Region
port: int
scaleset_id: UUID
machine_id: UUID
proxy_id: Optional[UUID]
dst_ip: str
dst_port: int
endtime: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
@classmethod
def key_fields(cls) -> Tuple[str, str]:
return ("region", "port")
@classmethod
def update_or_create(
cls,
region: Region,
scaleset_id: UUID,
machine_id: UUID,
dst_port: int,
duration: int,
) -> Union["ProxyForward", Error]:
private_ip = get_scaleset_instance_ip(scaleset_id, machine_id)
if not private_ip:
return Error(
code=ErrorCode.UNABLE_TO_PORT_FORWARD, errors=["no private ip for node"]
)
entries = cls.search_forward(
scaleset_id=scaleset_id,
machine_id=machine_id,
dst_port=dst_port,
region=region,
)
if entries:
entry = entries[0]
entry.endtime = datetime.datetime.utcnow() + datetime.timedelta(
hours=duration
)
entry.save()
return entry
existing = [int(x.port) for x in entries]
for port in PORT_RANGES:
if port in existing:
continue
entry = cls(
region=region,
port=port,
scaleset_id=scaleset_id,
machine_id=machine_id,
dst_ip=private_ip,
dst_port=dst_port,
endtime=datetime.datetime.utcnow() + datetime.timedelta(hours=duration),
)
result = entry.save(new=True)
if isinstance(result, Error):
logging.info("port is already used: %s", entry)
continue
return entry
return Error(
code=ErrorCode.UNABLE_TO_PORT_FORWARD, errors=["all forward ports used"]
)
@classmethod
def remove_forward(
cls,
scaleset_id: UUID,
*,
proxy_id: Optional[UUID] = None,
machine_id: Optional[UUID] = None,
dst_port: Optional[int] = None,
) -> List[Region]:
entries = cls.search_forward(
scaleset_id=scaleset_id,
machine_id=machine_id,
proxy_id=proxy_id,
dst_port=dst_port,
)
regions = set()
for entry in entries:
regions.add(entry.region)
entry.delete()
return list(regions)
@classmethod
def search_forward(
cls,
*,
scaleset_id: Optional[UUID] = None,
region: Optional[Region] = None,
machine_id: Optional[UUID] = None,
proxy_id: Optional[UUID] = None,
dst_port: Optional[int] = None,
) -> List["ProxyForward"]:
query: QueryFilter = {}
if region is not None:
query["region"] = [region]
if scaleset_id is not None:
query["scaleset_id"] = [scaleset_id]
if machine_id is not None:
query["machine_id"] = [machine_id]
if proxy_id is not None:
query["proxy_id"] = [proxy_id]
if dst_port is not None:
query["dst_port"] = [dst_port]
return cls.search(query=query)
def to_forward(self) -> Forward:
return Forward(src_port=self.port, dst_ip=self.dst_ip, dst_port=self.dst_port)
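A sketch of the port-selection loop above, reduced to picking the first free port in PORT_RANGES; in the real code the "already used" check is enforced by the conflicting table insert rather than a local set:
PORT_RANGES = range(28000, 32000)
def pick_free_port(in_use: set) -> int:
    # walk the allowed range and take the first port not already forwarded
    for port in PORT_RANGES:
        if port in in_use:
            continue
        return port
    raise RuntimeError("all forward ports used")
print(pick_free_port({28000, 28001}))   # 28002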

View File

@ -1,165 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import json
import logging
from sys import getsizeof
from typing import Optional, Union
from memoization import cached
from onefuzztypes.models import RegressionReport, Report
from onefuzztypes.primitives import Container
from pydantic import ValidationError
from .azure.containers import get_blob
from .azure.storage import StorageType
# This is a fix for the following error:
# Exception while executing function:
# Functions.queue_file_changes Result: Failure
# Exception: AzureHttpError: Bad Request
# "The property value exceeds the maximum allowed size (64KB).
# If the property value is a string, it is UTF-16 encoded and
# the maximum number of characters should be 32K or less.
def fix_report_size(
content: str,
report: Report,
acceptable_report_length_kb: int = 24,
keep_num_entries: int = 10,
keep_string_len: int = 256,
) -> Report:
logging.info(f"report content length {getsizeof(content)}")
if getsizeof(content) > acceptable_report_length_kb * 1024:
msg = f"report data exceeds {acceptable_report_length_kb}K {getsizeof(content)}"
if len(report.call_stack) > keep_num_entries:
msg = msg + "; removing some of stack frames from the report"
report.call_stack = report.call_stack[0:keep_num_entries] + ["..."]
if report.asan_log and len(report.asan_log) > keep_string_len:
msg = msg + "; removing some of asan log entries from the report"
report.asan_log = report.asan_log[0:keep_string_len] + "..."
if report.minimized_stack and len(report.minimized_stack) > keep_num_entries:
msg = msg + "; removing some of minimized stack frames from the report"
report.minimized_stack = report.minimized_stack[0:keep_num_entries] + [
"..."
]
if (
report.minimized_stack_function_names
and len(report.minimized_stack_function_names) > keep_num_entries
):
msg = (
msg
+ "; removing some of minimized stack function names from the report"
)
report.minimized_stack_function_names = (
report.minimized_stack_function_names[0:keep_num_entries] + ["..."]
)
if (
report.minimized_stack_function_lines
and len(report.minimized_stack_function_lines) > keep_num_entries
):
msg = (
msg
+ "; removing some of minimized stack function lines from the report"
)
report.minimized_stack_function_lines = (
report.minimized_stack_function_lines[0:keep_num_entries] + ["..."]
)
logging.info(msg)
return report
def parse_report_or_regression(
content: Union[str, bytes],
file_path: Optional[str] = None,
expect_reports: bool = False,
) -> Optional[Union[Report, RegressionReport]]:
if isinstance(content, bytes):
try:
content = content.decode()
except UnicodeDecodeError as err:
if expect_reports:
logging.error(
f"unable to parse report ({file_path}): "
f"unicode decode of report failed - {err}"
)
return None
try:
data = json.loads(content)
except json.decoder.JSONDecodeError as err:
if expect_reports:
logging.error(
f"unable to parse report ({file_path}): json decoding failed - {err}"
)
return None
regression_err = None
try:
regression_report = RegressionReport.parse_obj(data)
if (
regression_report.crash_test_result is not None
and regression_report.crash_test_result.crash_report is not None
):
regression_report.crash_test_result.crash_report = fix_report_size(
content, regression_report.crash_test_result.crash_report
)
if (
regression_report.original_crash_test_result is not None
and regression_report.original_crash_test_result.crash_report is not None
):
regression_report.original_crash_test_result.crash_report = fix_report_size(
content, regression_report.original_crash_test_result.crash_report
)
return regression_report
except ValidationError as err:
regression_err = err
try:
report = Report.parse_obj(data)
return fix_report_size(content, report)
except ValidationError as err:
if expect_reports:
logging.error(
f"unable to parse report ({file_path}) as a report or regression. "
f"regression error: {regression_err} report error: {err}"
)
return None
# cache the last 1000 reports
@cached(max_size=1000)
def get_report_or_regression(
container: Container, filename: str, *, expect_reports: bool = False
) -> Optional[Union[Report, RegressionReport]]:
file_path = "/".join([container, filename])
if not filename.endswith(".json"):
if expect_reports:
logging.error("get_report invalid extension: %s", file_path)
return None
blob = get_blob(container, filename, StorageType.corpus)
if blob is None:
if expect_reports:
logging.error("get_report invalid blob: %s", file_path)
return None
return parse_report_or_regression(
blob, file_path=file_path, expect_reports=expect_reports
)
def get_report(container: Container, filename: str) -> Optional[Report]:
result = get_report_or_regression(container, filename)
if isinstance(result, Report):
return result
return None
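A minimal sketch of the parse-as-regression-then-fall-back-to-report pattern above, using stand-in pydantic models (field names are made up):
import json
from typing import Optional, Union
from pydantic import BaseModel, ValidationError
class DemoReport(BaseModel):
    crash_type: str
    crash_site: str
class DemoRegressionReport(BaseModel):
    crash_test_result: dict
def parse_demo(content: str) -> Optional[Union[DemoReport, DemoRegressionReport]]:
    data = json.loads(content)
    try:
        return DemoRegressionReport.parse_obj(data)
    except ValidationError:
        pass
    try:
        return DemoReport.parse_obj(data)
    except ValidationError:
        return None
print(parse_demo('{"crash_type": "heap-overflow", "crash_site": "parse_header"}'))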

View File

@ -1,286 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
from datetime import datetime, timedelta
from typing import List, Optional, Tuple, Union
from azure.mgmt.compute.models import VirtualMachine
from onefuzztypes.enums import OS, ContainerType, ErrorCode, VmState
from onefuzztypes.models import Error
from onefuzztypes.models import Repro as BASE_REPRO
from onefuzztypes.models import ReproConfig, TaskVm, UserInfo
from onefuzztypes.primitives import Container
from .azure.auth import build_auth
from .azure.containers import save_blob
from .azure.creds import get_base_region
from .azure.ip import get_public_ip
from .azure.nsg import NSG
from .azure.storage import StorageType
from .azure.vm import VM
from .config import InstanceConfig
from .extension import repro_extensions
from .orm import ORMMixin, QueryFilter
from .reports import get_report
from .tasks.main import Task
DEFAULT_SKU = "Standard_DS1_v2"
class Repro(BASE_REPRO, ORMMixin):
def set_error(self, error: Error) -> None:
logging.error(
"repro failed: vm_id: %s task_id: %s: error: %s",
self.vm_id,
self.task_id,
error,
)
self.error = error
self.state = VmState.stopping
self.save()
def get_vm(self, config: InstanceConfig) -> VM:
tags = None
if config.vm_tags:
tags = config.vm_tags
task = Task.get_by_task_id(self.task_id)
if isinstance(task, Error):
raise Exception("previously existing task missing: %s" % self.task_id)
config = InstanceConfig.fetch()
default_os = {
OS.linux: config.default_linux_vm_image,
OS.windows: config.default_windows_vm_image,
}
vm_config = task.get_repro_vm_config()
if vm_config is None:
# if using a pool without any scalesets defined yet, use reasonable defaults
if task.os not in default_os:
raise NotImplementedError("unsupported OS for repro %s" % task.os)
vm_config = TaskVm(
region=get_base_region(), sku=DEFAULT_SKU, image=default_os[task.os]
)
if self.auth is None:
raise Exception("missing auth")
return VM(
name=self.vm_id,
region=vm_config.region,
sku=vm_config.sku,
image=vm_config.image,
auth=self.auth,
tags=tags,
)
def init(self) -> None:
config = InstanceConfig.fetch()
vm = self.get_vm(config)
vm_data = vm.get()
if vm_data:
if vm_data.provisioning_state == "Failed":
self.set_failed(vm_data)
else:
script_result = self.build_repro_script()
if isinstance(script_result, Error):
self.set_error(script_result)
return
self.state = VmState.extensions_launch
else:
nsg = NSG(
name=vm.region,
region=vm.region,
)
result = nsg.create()
if isinstance(result, Error):
self.set_error(result)
return
nsg_config = config.proxy_nsg_config
result = nsg.set_allowed_sources(nsg_config)
if isinstance(result, Error):
self.set_error(result)
return
vm.nsg = nsg
result = vm.create()
if isinstance(result, Error):
self.set_error(result)
return
self.save()
def set_failed(self, vm_data: VirtualMachine) -> None:
errors = []
for status in vm_data.instance_view.statuses:
if status.level.name.lower() == "error":
errors.append(
"%s %s %s" % (status.code, status.display_status, status.message)
)
return self.set_error(Error(code=ErrorCode.VM_CREATE_FAILED, errors=errors))
def get_setup_container(self) -> Optional[Container]:
task = Task.get_by_task_id(self.task_id)
if isinstance(task, Task):
for container in task.config.containers:
if container.type == ContainerType.setup:
return container.name
return None
def extensions_launch(self) -> None:
config = InstanceConfig.fetch()
vm = self.get_vm(config)
vm_data = vm.get()
if not vm_data:
self.set_error(
Error(
code=ErrorCode.VM_CREATE_FAILED,
errors=["failed before launching extensions"],
)
)
return
if vm_data.provisioning_state == "Failed":
self.set_failed(vm_data)
return
if not self.ip:
self.ip = get_public_ip(vm_data.network_profile.network_interfaces[0].id)
extensions = repro_extensions(
vm.region, self.os, self.vm_id, self.config, self.get_setup_container()
)
result = vm.add_extensions(extensions)
if isinstance(result, Error):
self.set_error(result)
return
elif result:
self.state = VmState.running
self.save()
def stopping(self) -> None:
config = InstanceConfig.fetch()
vm = self.get_vm(config)
if not vm.is_deleted():
logging.info("vm stopping: %s", self.vm_id)
vm.delete()
self.save()
else:
self.stopped()
def stopped(self) -> None:
logging.info("vm stopped: %s", self.vm_id)
self.delete()
def build_repro_script(self) -> Optional[Error]:
if self.auth is None:
return Error(code=ErrorCode.VM_CREATE_FAILED, errors=["missing auth"])
task = Task.get_by_task_id(self.task_id)
if isinstance(task, Error):
return task
report = get_report(self.config.container, self.config.path)
if report is None:
return Error(code=ErrorCode.VM_CREATE_FAILED, errors=["missing report"])
if report.input_blob is None:
return Error(
code=ErrorCode.VM_CREATE_FAILED,
errors=["unable to perform repro for crash reports without inputs"],
)
files = {}
if task.os == OS.windows:
ssh_path = "$env:ProgramData/ssh/administrators_authorized_keys"
cmds = [
'Set-Content -Path %s -Value "%s"' % (ssh_path, self.auth.public_key),
". C:\\onefuzz\\tools\\win64\\onefuzz.ps1",
"Set-SetSSHACL",
'while (1) { cdb -server tcp:port=1337 -c "g" setup\\%s %s }'
% (
task.config.task.target_exe,
report.input_blob.name,
),
]
cmd = "\r\n".join(cmds)
files["repro.ps1"] = cmd
elif task.os == OS.linux:
gdb_fmt = (
"ASAN_OPTIONS='abort_on_error=1' gdbserver "
"%s /onefuzz/setup/%s /onefuzz/downloaded/%s"
)
cmd = "while :; do %s; done" % (
gdb_fmt
% (
"localhost:1337",
task.config.task.target_exe,
report.input_blob.name,
)
)
files["repro.sh"] = cmd
cmd = "#!/bin/bash\n%s" % (
gdb_fmt % ("-", task.config.task.target_exe, report.input_blob.name)
)
files["repro-stdout.sh"] = cmd
else:
raise NotImplementedError("invalid task os: %s" % task.os)
for filename in files:
save_blob(
Container("repro-scripts"),
"%s/%s" % (self.vm_id, filename),
files[filename],
StorageType.config,
)
logging.info("saved repro script")
return None
@classmethod
def search_states(cls, *, states: Optional[List[VmState]] = None) -> List["Repro"]:
query: QueryFilter = {}
if states:
query["state"] = states
return cls.search(query=query)
@classmethod
def create(
cls, config: ReproConfig, user_info: Optional[UserInfo]
) -> Union[Error, "Repro"]:
report = get_report(config.container, config.path)
if not report:
return Error(
code=ErrorCode.UNABLE_TO_FIND, errors=["unable to find report"]
)
task = Task.get_by_task_id(report.task_id)
if isinstance(task, Error):
return task
vm = cls(config=config, task_id=task.task_id, os=task.os, auth=build_auth())
if vm.end_time is None:
vm.end_time = datetime.utcnow() + timedelta(hours=config.duration)
vm.user_info = user_info
vm.save()
return vm
@classmethod
def search_expired(cls) -> List["Repro"]:
# unlike jobs/tasks, the entry is deleted from the backing table upon stop
time_filter = "end_time lt datetime'%s'" % datetime.utcnow().isoformat()
return cls.search(raw_unchecked_filter=time_filter)
@classmethod
def key_fields(cls) -> Tuple[str, Optional[str]]:
return ("vm_id", None)

View File

@ -1,118 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import json
import logging
from typing import TYPE_CHECKING, Sequence, Type, TypeVar, Union
from uuid import UUID
from azure.functions import HttpRequest, HttpResponse
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error
from onefuzztypes.responses import BaseResponse
from pydantic import BaseModel # noqa: F401
from pydantic import ValidationError
from .orm import ModelMixin
# We don't actually use these types at runtime at this time. Rather,
# these are used in a bound TypeVar. MyPy suggests to only import these
# types during type checking.
if TYPE_CHECKING:
from onefuzztypes.requests import BaseRequest # noqa: F401
def ok(
data: Union[BaseResponse, Sequence[BaseResponse], ModelMixin, Sequence[ModelMixin]]
) -> HttpResponse:
if isinstance(data, BaseResponse):
return HttpResponse(data.json(exclude_none=True), mimetype="application/json")
if isinstance(data, list) and len(data) > 0 and isinstance(data[0], BaseResponse):
decoded = [json.loads(x.json(exclude_none=True)) for x in data]
return HttpResponse(json.dumps(decoded), mimetype="application/json")
if isinstance(data, ModelMixin):
return HttpResponse(
data.json(exclude_none=True, exclude=data.export_exclude()),
mimetype="application/json",
)
decoded = [
x.raw(exclude_none=True, exclude=x.export_exclude())
if isinstance(x, ModelMixin)
else x
for x in data
]
return HttpResponse(
json.dumps(decoded),
mimetype="application/json",
)
def not_ok(
error: Error, *, status_code: int = 400, context: Union[str, UUID]
) -> HttpResponse:
if 400 <= status_code <= 599:
logging.error("request error - %s: %s", str(context), error.json())
return HttpResponse(
error.json(), status_code=status_code, mimetype="application/json"
)
else:
raise Exception(
"status code %s is not int the expected range [400; 599]" % status_code
)
def redirect(url: str) -> HttpResponse:
return HttpResponse(status_code=302, headers={"Location": url})
def convert_error(err: ValidationError) -> Error:
errors = []
for error in err.errors():
if isinstance(error["loc"], tuple):
name = ".".join([str(x) for x in error["loc"]])
else:
name = str(error["loc"])
errors.append("%s: %s" % (name, error["msg"]))
return Error(code=ErrorCode.INVALID_REQUEST, errors=errors)
# TODO: loosen restrictions here during dev. We should be specific
# about only parsing things that are of a "Request" type, but there are
# a handful of types that need work in order to enforce that.
#
# These can be easily found by swapping the following comment and running
# mypy.
#
# A = TypeVar("A", bound="BaseRequest")
A = TypeVar("A", bound="BaseModel")
def parse_request(cls: Type[A], req: HttpRequest) -> Union[A, Error]:
try:
return cls.parse_obj(req.get_json())
except ValidationError as err:
return convert_error(err)
def parse_uri(cls: Type[A], req: HttpRequest) -> Union[A, Error]:
data = {}
for key in req.params:
data[key] = req.params[key]
try:
return cls.parse_obj(data)
except ValidationError as err:
return convert_error(err)
class RequestException(Exception):
def __init__(self, error: Error):
self.error = error
message = "error %s" % error
super().__init__(message)
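A hedged sketch of how a GET handler might be built on parse_uri, ok, and not_ok above (editorial, not part of the deleted file). ContainerGet is a made-up query model; BoolResult is assumed to be the simple boolean response type from onefuzztypes.responses.
import azure.functions as func
from onefuzztypes.models import Error
from onefuzztypes.responses import BoolResult
from pydantic import BaseModel
class ContainerGet(BaseModel):
    # hypothetical query model, not a real onefuzztypes request type
    name: str
def get(req: func.HttpRequest) -> func.HttpResponse:
    request = parse_uri(ContainerGet, req)
    if isinstance(request, Error):
        return not_ok(request, context="ContainerGet")
    # ... look up the container named request.name here ...
    return ok(BoolResult(result=True))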

View File

@ -1,109 +0,0 @@
from typing import Dict, List, Optional
from uuid import UUID
from onefuzztypes.models import ApiAccessRule
class RuleConflictError(Exception):
def __init__(self, message: str) -> None:
super(RuleConflictError, self).__init__(message)
self.message = message
class RequestAccess:
"""
Stores the rules associated with the request paths
"""
class Rules:
allowed_groups_ids: List[UUID]
def __init__(self, allowed_groups_ids: List[UUID] = []) -> None:
self.allowed_groups_ids = allowed_groups_ids
class Node:
# http method -> rules
rules: Dict[str, "RequestAccess.Rules"]
# path -> node
children: Dict[str, "RequestAccess.Node"]
def __init__(self) -> None:
self.rules = {}
self.children = {}
pass
root: Node
def __init__(self) -> None:
self.root = RequestAccess.Node()
def __add_url__(self, methods: List[str], path: str, rules: Rules) -> None:
methods = list(map(lambda m: m.upper(), methods))
segments = [s for s in path.split("/") if s != ""]
if len(segments) == 0:
return
current_node = self.root
current_segment_index = 0
while current_segment_index < len(segments):
current_segment = segments[current_segment_index]
if current_segment in current_node.children:
current_node = current_node.children[current_segment]
current_segment_index = current_segment_index + 1
else:
break
# we found a node matching this exact path
# This means that there is an existing rule causing a conflict
if current_segment_index == len(segments):
for method in methods:
if method in current_node.rules:
raise RuleConflictError(f"Conflicting rules on {method} {path}")
while current_segment_index < len(segments):
current_segment = segments[current_segment_index]
current_node.children[current_segment] = RequestAccess.Node()
current_node = current_node.children[current_segment]
current_segment_index = current_segment_index + 1
for method in methods:
current_node.rules[method] = rules
def get_matching_rules(self, method: str, path: str) -> Optional[Rules]:
method = method.upper()
segments = [s for s in path.split("/") if s != ""]
current_node = self.root
current_rule = None
if method in current_node.rules:
current_rule = current_node.rules[method]
current_segment_index = 0
while current_segment_index < len(segments):
current_segment = segments[current_segment_index]
if current_segment in current_node.children:
current_node = current_node.children[current_segment]
elif "*" in current_node.children:
current_node = current_node.children["*"]
else:
break
if method in current_node.rules:
current_rule = current_node.rules[method]
current_segment_index = current_segment_index + 1
return current_rule
@classmethod
def build(cls, rules: Dict[str, ApiAccessRule]) -> "RequestAccess":
request_access = RequestAccess()
for endpoint in rules:
rule = rules[endpoint]
request_access.__add_url__(
rule.methods,
endpoint,
RequestAccess.Rules(allowed_groups_ids=rule.allowed_groups),
)
return request_access
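A hedged sketch of how RequestAccess is meant to be used, assuming ApiAccessRule only needs the methods and allowed_groups fields consumed by build above; the endpoint and group id are made up (not part of the deleted file).
from uuid import uuid4
from onefuzztypes.models import ApiAccessRule
group_id = uuid4()
access = RequestAccess.build(
    {"api/tasks/*": ApiAccessRule(methods=["POST"], allowed_groups=[group_id])}
)
# the trailing "*" segment matches any single path segment
matched = access.get_matching_rules("post", "api/tasks/11111111")
assert matched is not None and group_id in matched.allowed_groups_ids
# no rule was registered for GET on this path
assert access.get_matching_rules("GET", "api/tasks/11111111") is None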

View File

@ -1,87 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
from typing import Tuple, Type, TypeVar
from urllib.parse import urlparse
from uuid import uuid4
from azure.keyvault.secrets import KeyVaultSecret
from onefuzztypes.models import SecretAddress, SecretData
from pydantic import BaseModel
from .azure.creds import get_keyvault_client
A = TypeVar("A", bound=BaseModel)
def save_to_keyvault(secret_data: SecretData) -> SecretData:
if isinstance(secret_data.secret, SecretAddress):
return secret_data
secret_name = str(uuid4())
if isinstance(secret_data.secret, str):
secret_value = secret_data.secret
elif isinstance(secret_data.secret, BaseModel):
secret_value = secret_data.secret.json()
else:
raise Exception("invalid secret data")
kv = store_in_keyvault(get_keyvault_address(), secret_name, secret_value)
secret_data.secret = SecretAddress(url=kv.id)
return secret_data
def get_secret_string_value(self: SecretData[str]) -> str:
if isinstance(self.secret, SecretAddress):
secret = get_secret(self.secret.url)
return secret.value
else:
return self.secret
def get_keyvault_address() -> str:
# https://docs.microsoft.com/en-us/azure/key-vault/general/about-keys-secrets-certificates#vault-name-and-object-name
keyvault_name = os.environ["ONEFUZZ_KEYVAULT"]
return f"https://{keyvault_name}.vault.azure.net"
def store_in_keyvault(
keyvault_url: str, secret_name: str, secret_value: str
) -> KeyVaultSecret:
keyvault_client = get_keyvault_client(keyvault_url)
kvs: KeyVaultSecret = keyvault_client.set_secret(secret_name, secret_value)
return kvs
def parse_secret_url(secret_url: str) -> Tuple[str, str]:
# format: https://{vault-name}.vault.azure.net/secrets/{secret-name}/{version}
u = urlparse(secret_url)
vault_url = f"{u.scheme}://{u.netloc}"
secret_name = u.path.split("/")[2]
return (vault_url, secret_name)
def get_secret(secret_url: str) -> KeyVaultSecret:
(vault_url, secret_name) = parse_secret_url(secret_url)
keyvault_client = get_keyvault_client(vault_url)
return keyvault_client.get_secret(secret_name)
def get_secret_obj(secret_url: str, model: Type[A]) -> A:
secret = get_secret(secret_url)
return model.parse_raw(secret.value)
def delete_secret(secret_url: str) -> None:
(vault_url, secret_name) = parse_secret_url(secret_url)
keyvault_client = get_keyvault_client(vault_url)
keyvault_client.begin_delete_secret(secret_name)
def delete_remote_secret_data(data: SecretData) -> None:
if isinstance(data.secret, SecretAddress):
delete_secret(data.secret.url)
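A small, runnable check of parse_secret_url above, with a made-up vault and secret name (editorial illustration only, not part of the deleted file).
url = "https://my-vault.vault.azure.net/secrets/my-secret/0123456789abcdef"
assert parse_secret_url(url) == ("https://my-vault.vault.azure.net", "my-secret")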

View File

@ -1,52 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import datetime
from typing import List, Optional, Tuple
from uuid import UUID
from onefuzztypes.models import TaskEvent as BASE_TASK_EVENT
from onefuzztypes.models import TaskEventSummary, WorkerEvent
from .orm import ORMMixin
class TaskEvent(BASE_TASK_EVENT, ORMMixin):
@classmethod
def get_summary(cls, task_id: UUID) -> List[TaskEventSummary]:
events = cls.search(query={"task_id": [task_id]})
# handle None case of Optional[e.timestamp], which shouldn't happen
events.sort(key=lambda e: e.timestamp or datetime.datetime.max)
return [
TaskEventSummary(
timestamp=e.timestamp,
event_data=get_event_data(e.event_data),
event_type=get_event_type(e.event_data),
)
for e in events
]
@classmethod
def key_fields(cls) -> Tuple[str, Optional[str]]:
return ("task_id", None)
def get_event_data(event: WorkerEvent) -> str:
if event.done:
return "exit status: %s" % event.done.exit_status
elif event.running:
return ""
else:
return "Unrecognized event: %s" % event
def get_event_type(event: WorkerEvent) -> str:
if event.done:
return type(event.done).__name__
elif event.running:
return type(event.running).__name__
else:
return "Unrecognized event: %s" % event

View File

@ -1,466 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import os
import pathlib
from typing import Dict, List, Optional
from onefuzztypes.enums import Compare, ContainerPermission, ContainerType, TaskFeature
from onefuzztypes.models import Job, Task, TaskConfig, TaskDefinition, TaskUnitConfig
from onefuzztypes.primitives import Container
from ..azure.containers import (
add_container_sas_url,
blob_exists,
container_exists,
get_container_sas_url,
)
from ..azure.creds import get_instance_id
from ..azure.queue import get_queue_sas
from ..azure.storage import StorageType
from ..workers.pools import Pool
from .defs import TASK_DEFINITIONS
LOGGER = logging.getLogger("onefuzz")
def get_input_container_queues(config: TaskConfig) -> Optional[List[str]]: # tasks.Task
if config.task.type not in TASK_DEFINITIONS:
raise TaskConfigError("unsupported task type: %s" % config.task.type.name)
container_type = TASK_DEFINITIONS[config.task.type].monitor_queue
if container_type:
return [x.name for x in config.containers if x.type == container_type]
return None
def check_val(compare: Compare, expected: int, actual: int) -> bool:
if compare == Compare.Equal:
return expected == actual
if compare == Compare.AtLeast:
return expected <= actual
if compare == Compare.AtMost:
return expected >= actual
raise NotImplementedError
def check_container(
compare: Compare,
expected: int,
container_type: ContainerType,
containers: Dict[ContainerType, List[Container]],
) -> None:
actual = len(containers.get(container_type, []))
if not check_val(compare, expected, actual):
raise TaskConfigError(
"container type %s: expected %s %d, got %d"
% (container_type.name, compare.name, expected, actual)
)
def check_containers(definition: TaskDefinition, config: TaskConfig) -> None:
checked = set()
containers: Dict[ContainerType, List[Container]] = {}
for container in config.containers:
if container.name not in checked:
if not container_exists(container.name, StorageType.corpus):
raise TaskConfigError("missing container: %s" % container.name)
checked.add(container.name)
if container.type not in containers:
containers[container.type] = []
containers[container.type].append(container.name)
for container_def in definition.containers:
check_container(
container_def.compare, container_def.value, container_def.type, containers
)
for container_type in containers:
if container_type not in [x.type for x in definition.containers]:
raise TaskConfigError(
"unsupported container type for this task: %s" % container_type.name
)
if definition.monitor_queue:
if definition.monitor_queue not in [x.type for x in definition.containers]:
raise TaskConfigError(
"unable to monitor container type as it is not used by this task: %s"
% definition.monitor_queue.name
)
def check_target_exe(config: TaskConfig, definition: TaskDefinition) -> None:
if config.task.target_exe is None:
if TaskFeature.target_exe in definition.features:
raise TaskConfigError("missing target_exe")
if TaskFeature.target_exe_optional in definition.features:
return
return
# User-submitted paths must be relative to the setup directory that contains them.
# They also must be normalized, and exclude special filesystem path elements.
#
# For example, accessing the blob store path "./foo" generates an exception, but
# "foo" and "foo/bar" do not.
if not is_valid_blob_name(config.task.target_exe):
raise TaskConfigError("target_exe must be a canonicalized relative path")
container = [x for x in config.containers if x.type == ContainerType.setup][0]
if not blob_exists(container.name, config.task.target_exe, StorageType.corpus):
err = "target_exe `%s` does not exist in the setup container `%s`" % (
config.task.target_exe,
container.name,
)
LOGGER.warning(err)
# Azure Blob Storage uses a flat scheme, and has no true directory hierarchy. Forward
# slashes are used to delimit a _virtual_ directory structure.
def is_valid_blob_name(blob_name: str) -> bool:
# https://docs.microsoft.com/en-us/rest/api/storageservices/naming-and-referencing-containers--blobs--and-metadata#blob-names
MIN_LENGTH = 1
MAX_LENGTH = 1024 # Inclusive
MAX_PATH_SEGMENTS = 254
length = len(blob_name)
# No leading/trailing whitespace.
if blob_name != blob_name.strip():
return False
if length < MIN_LENGTH:
return False
if length > MAX_LENGTH:
return False
path = pathlib.PurePosixPath(blob_name)
if len(path.parts) > MAX_PATH_SEGMENTS:
return False
# No path segment should end with a dot (`.`).
for part in path.parts:
if part.endswith("."):
return False
# Reject absolute paths to avoid confusion.
if path.is_absolute():
return False
# Reject paths with special relative filesystem entries.
if "." in path.parts:
return False
if ".." in path.parts:
return False
# Will not have a leading `.`, even if `blob_name` does.
normalized = path.as_posix()
return blob_name == normalized
def target_uses_input(config: TaskConfig) -> bool:
if config.task.target_options is not None:
for option in config.task.target_options:
if "{input}" in option:
return True
if config.task.target_env is not None:
for value in config.task.target_env.values():
if "{input}" in value:
return True
return False
def check_config(config: TaskConfig) -> None:
if config.task.type not in TASK_DEFINITIONS:
raise TaskConfigError("unsupported task type: %s" % config.task.type.name)
if config.vm is not None and config.pool is not None:
raise TaskConfigError("either the vm or pool must be specified, but not both")
definition = TASK_DEFINITIONS[config.task.type]
check_containers(definition, config)
if (
TaskFeature.supervisor_exe in definition.features
and not config.task.supervisor_exe
):
err = "missing supervisor_exe"
LOGGER.error(err)
raise TaskConfigError("missing supervisor_exe")
if (
TaskFeature.target_must_use_input in definition.features
and not target_uses_input(config)
):
raise TaskConfigError("{input} must be used in target_env or target_options")
if config.vm:
err = "specifying task config vm is no longer supported"
raise TaskConfigError(err)
if not config.pool:
raise TaskConfigError("pool must be specified")
if not check_val(definition.vm.compare, definition.vm.value, config.pool.count):
err = "invalid vm count: expected %s %d, got %s" % (
definition.vm.compare,
definition.vm.value,
config.pool.count,
)
LOGGER.error(err)
raise TaskConfigError(err)
pool = Pool.get_by_name(config.pool.pool_name)
if not isinstance(pool, Pool):
raise TaskConfigError(f"invalid pool: {config.pool.pool_name}")
check_target_exe(config, definition)
if TaskFeature.generator_exe in definition.features:
container = [x for x in config.containers if x.type == ContainerType.tools][0]
if not config.task.generator_exe:
raise TaskConfigError("generator_exe is not defined")
tools_paths = ["{tools_dir}/", "{tools_dir}\\"]
for tool_path in tools_paths:
if config.task.generator_exe.startswith(tool_path):
generator = config.task.generator_exe.replace(tool_path, "")
if not blob_exists(container.name, generator, StorageType.corpus):
err = (
"generator_exe `%s` does not exist in the tools container `%s`"
% (
config.task.generator_exe,
container.name,
)
)
LOGGER.error(err)
raise TaskConfigError(err)
if TaskFeature.stats_file in definition.features:
if config.task.stats_file is not None and config.task.stats_format is None:
err = "using a stats_file requires a stats_format"
LOGGER.error(err)
raise TaskConfigError(err)
def build_task_config(job: Job, task: Task) -> TaskUnitConfig:
job_id = job.job_id
task_id = task.task_id
task_config = task.config
if task_config.task.type not in TASK_DEFINITIONS:
raise TaskConfigError("unsupported task type: %s" % task_config.task.type.name)
definition = TASK_DEFINITIONS[task_config.task.type]
config = TaskUnitConfig(
job_id=job_id,
task_id=task_id,
task_type=task_config.task.type,
instance_telemetry_key=os.environ.get("APPINSIGHTS_INSTRUMENTATIONKEY"),
microsoft_telemetry_key=os.environ.get("ONEFUZZ_TELEMETRY"),
heartbeat_queue=get_queue_sas(
"task-heartbeat",
StorageType.config,
add=True,
),
instance_id=get_instance_id(),
)
if job.config.logs:
config.logs = add_container_sas_url(job.config.logs)
else:
LOGGER.warning("Missing log container: job_id %s, task_id %s", job_id, task_id)
if definition.monitor_queue:
config.input_queue = get_queue_sas(
task_id,
StorageType.corpus,
add=True,
read=True,
update=True,
process=True,
)
for container_def in definition.containers:
if container_def.type == ContainerType.setup:
continue
containers = []
for (i, container) in enumerate(task_config.containers):
if container.type != container_def.type:
continue
containers.append(
{
"path": "_".join(["task", container_def.type.name, str(i)]),
"url": get_container_sas_url(
container.name,
StorageType.corpus,
read=ContainerPermission.Read in container_def.permissions,
write=ContainerPermission.Write in container_def.permissions,
delete=ContainerPermission.Delete in container_def.permissions,
list_=ContainerPermission.List in container_def.permissions,
),
}
)
if not containers:
continue
if (
container_def.compare in [Compare.Equal, Compare.AtMost]
and container_def.value == 1
):
setattr(config, container_def.type.name, containers[0])
else:
setattr(config, container_def.type.name, containers)
EMPTY_DICT: Dict[str, str] = {}
EMPTY_LIST: List[str] = []
if TaskFeature.supervisor_exe in definition.features:
config.supervisor_exe = task_config.task.supervisor_exe
if TaskFeature.supervisor_env in definition.features:
config.supervisor_env = task_config.task.supervisor_env or EMPTY_DICT
if TaskFeature.supervisor_options in definition.features:
config.supervisor_options = task_config.task.supervisor_options or EMPTY_LIST
if TaskFeature.supervisor_input_marker in definition.features:
config.supervisor_input_marker = task_config.task.supervisor_input_marker
if TaskFeature.target_exe in definition.features:
config.target_exe = task_config.task.target_exe
if (
TaskFeature.target_exe_optional in definition.features
and task_config.task.target_exe
):
config.target_exe = task_config.task.target_exe
if TaskFeature.target_env in definition.features:
config.target_env = task_config.task.target_env or EMPTY_DICT
if TaskFeature.target_options in definition.features:
config.target_options = task_config.task.target_options or EMPTY_LIST
if TaskFeature.target_options_merge in definition.features:
config.target_options_merge = task_config.task.target_options_merge or False
if TaskFeature.target_workers in definition.features:
config.target_workers = task_config.task.target_workers
if TaskFeature.rename_output in definition.features:
config.rename_output = task_config.task.rename_output or False
if TaskFeature.generator_exe in definition.features:
config.generator_exe = task_config.task.generator_exe
if TaskFeature.generator_env in definition.features:
config.generator_env = task_config.task.generator_env or EMPTY_DICT
if TaskFeature.generator_options in definition.features:
config.generator_options = task_config.task.generator_options or EMPTY_LIST
if (
TaskFeature.wait_for_files in definition.features
and task_config.task.wait_for_files
):
config.wait_for_files = task_config.task.wait_for_files.name
if TaskFeature.analyzer_exe in definition.features:
config.analyzer_exe = task_config.task.analyzer_exe
if TaskFeature.analyzer_options in definition.features:
config.analyzer_options = task_config.task.analyzer_options or EMPTY_LIST
if TaskFeature.analyzer_env in definition.features:
config.analyzer_env = task_config.task.analyzer_env or EMPTY_DICT
if TaskFeature.stats_file in definition.features:
config.stats_file = task_config.task.stats_file
config.stats_format = task_config.task.stats_format
if TaskFeature.target_timeout in definition.features:
config.target_timeout = task_config.task.target_timeout
if TaskFeature.check_asan_log in definition.features:
config.check_asan_log = task_config.task.check_asan_log
if TaskFeature.check_debugger in definition.features:
config.check_debugger = task_config.task.check_debugger
if TaskFeature.check_retry_count in definition.features:
config.check_retry_count = task_config.task.check_retry_count or 0
if TaskFeature.ensemble_sync_delay in definition.features:
config.ensemble_sync_delay = task_config.task.ensemble_sync_delay
if TaskFeature.check_fuzzer_help in definition.features:
config.check_fuzzer_help = (
task_config.task.check_fuzzer_help
if task_config.task.check_fuzzer_help is not None
else True
)
if TaskFeature.report_list in definition.features:
config.report_list = task_config.task.report_list
if TaskFeature.minimized_stack_depth in definition.features:
config.minimized_stack_depth = task_config.task.minimized_stack_depth
if TaskFeature.expect_crash_on_failure in definition.features:
config.expect_crash_on_failure = (
task_config.task.expect_crash_on_failure
if task_config.task.expect_crash_on_failure is not None
else True
)
if TaskFeature.coverage_filter in definition.features:
coverage_filter = task_config.task.coverage_filter
if coverage_filter is not None:
config.coverage_filter = coverage_filter
if TaskFeature.target_assembly in definition.features:
config.target_assembly = task_config.task.target_assembly
if TaskFeature.target_class in definition.features:
config.target_class = task_config.task.target_class
if TaskFeature.target_method in definition.features:
config.target_method = task_config.task.target_method
return config
def get_setup_container(config: TaskConfig) -> Container:
for container in config.containers:
if container.type == ContainerType.setup:
return container.name
raise TaskConfigError(
"task missing setup container: task_type = %s" % config.task.type
)
class TaskConfigError(Exception):
pass
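To make the blob-name and count-comparison rules above concrete, a short editorial sketch; it assumes the module's own names (is_valid_blob_name, check_val, Compare) are in scope and is not part of the deleted file.
assert is_valid_blob_name("foo")
assert is_valid_blob_name("foo/bar.txt")
assert not is_valid_blob_name("./foo")       # not canonicalized
assert not is_valid_blob_name("foo/../bar")  # special relative segment
assert not is_valid_blob_name("/abs/path")   # absolute paths are rejected
assert not is_valid_blob_name("foo.")        # segments must not end with "."
# check_val underpins the container and vm count comparisons
assert check_val(Compare.AtLeast, expected=1, actual=3)
assert not check_val(Compare.Equal, expected=1, actual=3)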

View File

@ -1,707 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from onefuzztypes.enums import (
Compare,
ContainerPermission,
ContainerType,
TaskFeature,
TaskType,
)
from onefuzztypes.models import ContainerDefinition, TaskDefinition, VmDefinition
# all tasks are required to have a 'setup' container
TASK_DEFINITIONS = {
TaskType.coverage: TaskDefinition(
features=[
TaskFeature.target_exe,
TaskFeature.target_env,
TaskFeature.target_options,
TaskFeature.target_timeout,
TaskFeature.coverage_filter,
TaskFeature.target_must_use_input,
],
vm=VmDefinition(compare=Compare.Equal, value=1),
containers=[
ContainerDefinition(
type=ContainerType.setup,
compare=Compare.Equal,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
ContainerDefinition(
type=ContainerType.readonly_inputs,
compare=Compare.AtLeast,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
ContainerDefinition(
type=ContainerType.coverage,
compare=Compare.Equal,
value=1,
permissions=[
ContainerPermission.List,
ContainerPermission.Read,
ContainerPermission.Write,
],
),
],
monitor_queue=ContainerType.readonly_inputs,
),
TaskType.dotnet_coverage: TaskDefinition(
features=[
TaskFeature.target_exe,
TaskFeature.target_env,
TaskFeature.target_options,
TaskFeature.target_timeout,
TaskFeature.coverage_filter,
TaskFeature.target_must_use_input,
],
vm=VmDefinition(compare=Compare.Equal, value=1),
containers=[
ContainerDefinition(
type=ContainerType.setup,
compare=Compare.Equal,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
ContainerDefinition(
type=ContainerType.readonly_inputs,
compare=Compare.AtLeast,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
ContainerDefinition(
type=ContainerType.coverage,
compare=Compare.Equal,
value=1,
permissions=[
ContainerPermission.List,
ContainerPermission.Read,
ContainerPermission.Write,
],
),
ContainerDefinition(
type=ContainerType.tools,
compare=Compare.Equal,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
],
monitor_queue=ContainerType.readonly_inputs,
),
TaskType.dotnet_crash_report: TaskDefinition(
features=[
TaskFeature.target_exe,
TaskFeature.target_env,
TaskFeature.target_options,
TaskFeature.target_timeout,
TaskFeature.check_asan_log,
TaskFeature.check_debugger,
TaskFeature.check_retry_count,
TaskFeature.minimized_stack_depth,
],
vm=VmDefinition(compare=Compare.AtLeast, value=1),
containers=[
ContainerDefinition(
type=ContainerType.setup,
compare=Compare.Equal,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
ContainerDefinition(
type=ContainerType.crashes,
compare=Compare.Equal,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
ContainerDefinition(
type=ContainerType.reports,
compare=Compare.AtMost,
value=1,
permissions=[ContainerPermission.Write],
),
ContainerDefinition(
type=ContainerType.unique_reports,
compare=Compare.AtMost,
value=1,
permissions=[ContainerPermission.Write],
),
ContainerDefinition(
type=ContainerType.no_repro,
compare=Compare.AtMost,
value=1,
permissions=[ContainerPermission.Write],
),
ContainerDefinition(
type=ContainerType.tools,
compare=Compare.Equal,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
],
monitor_queue=ContainerType.crashes,
),
TaskType.libfuzzer_dotnet_fuzz: TaskDefinition(
features=[
TaskFeature.target_exe,
TaskFeature.target_env,
TaskFeature.target_options,
TaskFeature.target_workers,
TaskFeature.ensemble_sync_delay,
TaskFeature.check_fuzzer_help,
TaskFeature.expect_crash_on_failure,
TaskFeature.target_assembly,
TaskFeature.target_class,
TaskFeature.target_method,
],
vm=VmDefinition(compare=Compare.AtLeast, value=1),
containers=[
ContainerDefinition(
type=ContainerType.setup,
compare=Compare.Equal,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
ContainerDefinition(
type=ContainerType.crashes,
compare=Compare.Equal,
value=1,
permissions=[ContainerPermission.Write],
),
ContainerDefinition(
type=ContainerType.inputs,
compare=Compare.Equal,
value=1,
permissions=[
ContainerPermission.Write,
ContainerPermission.Read,
ContainerPermission.List,
],
),
ContainerDefinition(
type=ContainerType.readonly_inputs,
compare=Compare.AtLeast,
value=0,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
ContainerDefinition(
type=ContainerType.tools,
compare=Compare.Equal,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
],
monitor_queue=None,
),
TaskType.generic_analysis: TaskDefinition(
features=[
TaskFeature.target_exe,
TaskFeature.target_options,
TaskFeature.analyzer_exe,
TaskFeature.analyzer_env,
TaskFeature.analyzer_options,
],
vm=VmDefinition(compare=Compare.AtLeast, value=1),
containers=[
ContainerDefinition(
type=ContainerType.setup,
compare=Compare.Equal,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
ContainerDefinition(
type=ContainerType.analysis,
compare=Compare.Equal,
value=1,
permissions=[
ContainerPermission.Write,
ContainerPermission.Read,
ContainerPermission.List,
],
),
ContainerDefinition(
type=ContainerType.crashes,
compare=Compare.Equal,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
ContainerDefinition(
type=ContainerType.tools,
compare=Compare.Equal,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
],
monitor_queue=ContainerType.crashes,
),
TaskType.libfuzzer_fuzz: TaskDefinition(
features=[
TaskFeature.target_exe,
TaskFeature.target_env,
TaskFeature.target_options,
TaskFeature.target_workers,
TaskFeature.ensemble_sync_delay,
TaskFeature.check_fuzzer_help,
TaskFeature.expect_crash_on_failure,
],
vm=VmDefinition(compare=Compare.AtLeast, value=1),
containers=[
ContainerDefinition(
type=ContainerType.setup,
compare=Compare.Equal,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
ContainerDefinition(
type=ContainerType.crashes,
compare=Compare.Equal,
value=1,
permissions=[ContainerPermission.Write],
),
ContainerDefinition(
type=ContainerType.inputs,
compare=Compare.Equal,
value=1,
permissions=[
ContainerPermission.Write,
ContainerPermission.Read,
ContainerPermission.List,
],
),
ContainerDefinition(
type=ContainerType.readonly_inputs,
compare=Compare.AtLeast,
value=0,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
],
monitor_queue=None,
),
TaskType.libfuzzer_crash_report: TaskDefinition(
features=[
TaskFeature.target_exe,
TaskFeature.target_env,
TaskFeature.target_options,
TaskFeature.target_timeout,
TaskFeature.check_retry_count,
TaskFeature.check_fuzzer_help,
TaskFeature.minimized_stack_depth,
],
vm=VmDefinition(compare=Compare.AtLeast, value=1),
containers=[
ContainerDefinition(
type=ContainerType.setup,
compare=Compare.Equal,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
ContainerDefinition(
type=ContainerType.crashes,
compare=Compare.Equal,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
ContainerDefinition(
type=ContainerType.reports,
compare=Compare.AtMost,
value=1,
permissions=[ContainerPermission.Write],
),
ContainerDefinition(
type=ContainerType.unique_reports,
compare=Compare.AtMost,
value=1,
permissions=[ContainerPermission.Write],
),
ContainerDefinition(
type=ContainerType.no_repro,
compare=Compare.AtMost,
value=1,
permissions=[ContainerPermission.Write],
),
],
monitor_queue=ContainerType.crashes,
),
TaskType.libfuzzer_merge: TaskDefinition(
features=[
TaskFeature.target_exe,
TaskFeature.target_env,
TaskFeature.target_options,
TaskFeature.check_fuzzer_help,
],
vm=VmDefinition(compare=Compare.Equal, value=1),
containers=[
ContainerDefinition(
type=ContainerType.setup,
compare=Compare.Equal,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
ContainerDefinition(
type=ContainerType.unique_inputs,
compare=Compare.Equal,
value=1,
permissions=[
ContainerPermission.List,
ContainerPermission.Read,
ContainerPermission.Write,
],
),
ContainerDefinition(
type=ContainerType.inputs,
compare=Compare.AtLeast,
value=0,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
],
monitor_queue=None,
),
TaskType.generic_supervisor: TaskDefinition(
features=[
TaskFeature.target_exe,
TaskFeature.target_options,
TaskFeature.supervisor_exe,
TaskFeature.supervisor_env,
TaskFeature.supervisor_options,
TaskFeature.supervisor_input_marker,
TaskFeature.wait_for_files,
TaskFeature.stats_file,
TaskFeature.ensemble_sync_delay,
],
vm=VmDefinition(compare=Compare.AtLeast, value=1),
containers=[
ContainerDefinition(
type=ContainerType.setup,
compare=Compare.Equal,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
ContainerDefinition(
type=ContainerType.tools,
compare=Compare.AtMost,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
ContainerDefinition(
type=ContainerType.crashes,
compare=Compare.Equal,
value=1,
permissions=[ContainerPermission.Write],
),
ContainerDefinition(
type=ContainerType.inputs,
compare=Compare.Equal,
value=1,
permissions=[
ContainerPermission.Write,
ContainerPermission.Read,
ContainerPermission.List,
],
),
ContainerDefinition(
type=ContainerType.unique_reports,
compare=Compare.AtMost,
value=1,
permissions=[
ContainerPermission.Write,
ContainerPermission.Read,
ContainerPermission.List,
],
),
ContainerDefinition(
type=ContainerType.reports,
compare=Compare.AtMost,
value=1,
permissions=[
ContainerPermission.Write,
ContainerPermission.Read,
ContainerPermission.List,
],
),
ContainerDefinition(
type=ContainerType.no_repro,
compare=Compare.AtMost,
value=1,
permissions=[
ContainerPermission.Write,
ContainerPermission.Read,
ContainerPermission.List,
],
),
ContainerDefinition(
type=ContainerType.coverage,
compare=Compare.AtMost,
value=1,
permissions=[
ContainerPermission.Write,
ContainerPermission.Read,
ContainerPermission.List,
],
),
],
monitor_queue=None,
),
TaskType.generic_merge: TaskDefinition(
features=[
TaskFeature.target_exe,
TaskFeature.target_options,
TaskFeature.supervisor_exe,
TaskFeature.supervisor_env,
TaskFeature.supervisor_options,
TaskFeature.supervisor_input_marker,
TaskFeature.stats_file,
TaskFeature.preserve_existing_outputs,
],
vm=VmDefinition(compare=Compare.AtLeast, value=1),
containers=[
ContainerDefinition(
type=ContainerType.setup,
compare=Compare.Equal,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
ContainerDefinition(
type=ContainerType.tools,
compare=Compare.Equal,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
ContainerDefinition(
type=ContainerType.readonly_inputs,
compare=Compare.AtLeast,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
ContainerDefinition(
type=ContainerType.inputs,
compare=Compare.Equal,
value=1,
permissions=[ContainerPermission.Write, ContainerPermission.List],
),
],
monitor_queue=None,
),
TaskType.generic_generator: TaskDefinition(
features=[
TaskFeature.generator_exe,
TaskFeature.generator_env,
TaskFeature.generator_options,
TaskFeature.target_exe,
TaskFeature.target_env,
TaskFeature.target_options,
TaskFeature.rename_output,
TaskFeature.target_timeout,
TaskFeature.check_asan_log,
TaskFeature.check_debugger,
TaskFeature.check_retry_count,
TaskFeature.ensemble_sync_delay,
TaskFeature.target_must_use_input,
],
vm=VmDefinition(compare=Compare.AtLeast, value=1),
containers=[
ContainerDefinition(
type=ContainerType.setup,
compare=Compare.Equal,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
ContainerDefinition(
type=ContainerType.tools,
compare=Compare.Equal,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
ContainerDefinition(
type=ContainerType.crashes,
compare=Compare.Equal,
value=1,
permissions=[ContainerPermission.Write],
),
ContainerDefinition(
type=ContainerType.readonly_inputs,
compare=Compare.AtLeast,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
],
monitor_queue=None,
),
TaskType.generic_crash_report: TaskDefinition(
features=[
TaskFeature.target_exe,
TaskFeature.target_env,
TaskFeature.target_options,
TaskFeature.target_timeout,
TaskFeature.check_asan_log,
TaskFeature.check_debugger,
TaskFeature.check_retry_count,
TaskFeature.minimized_stack_depth,
],
vm=VmDefinition(compare=Compare.AtLeast, value=1),
containers=[
ContainerDefinition(
type=ContainerType.setup,
compare=Compare.Equal,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
ContainerDefinition(
type=ContainerType.crashes,
compare=Compare.Equal,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
ContainerDefinition(
type=ContainerType.reports,
compare=Compare.AtMost,
value=1,
permissions=[ContainerPermission.Write],
),
ContainerDefinition(
type=ContainerType.unique_reports,
compare=Compare.AtMost,
value=1,
permissions=[ContainerPermission.Write],
),
ContainerDefinition(
type=ContainerType.no_repro,
compare=Compare.AtMost,
value=1,
permissions=[ContainerPermission.Write],
),
],
monitor_queue=ContainerType.crashes,
),
TaskType.generic_regression: TaskDefinition(
features=[
TaskFeature.target_exe,
TaskFeature.target_env,
TaskFeature.target_options,
TaskFeature.target_timeout,
TaskFeature.check_asan_log,
TaskFeature.check_debugger,
TaskFeature.check_retry_count,
TaskFeature.report_list,
TaskFeature.minimized_stack_depth,
],
vm=VmDefinition(compare=Compare.AtLeast, value=1),
containers=[
ContainerDefinition(
type=ContainerType.setup,
compare=Compare.Equal,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
ContainerDefinition(
type=ContainerType.regression_reports,
compare=Compare.Equal,
value=1,
permissions=[
ContainerPermission.Write,
ContainerPermission.Read,
ContainerPermission.List,
],
),
ContainerDefinition(
type=ContainerType.crashes,
compare=Compare.Equal,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
ContainerDefinition(
type=ContainerType.reports,
compare=Compare.AtMost,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
ContainerDefinition(
type=ContainerType.unique_reports,
compare=Compare.AtMost,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
ContainerDefinition(
type=ContainerType.no_repro,
compare=Compare.AtMost,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
ContainerDefinition(
type=ContainerType.readonly_inputs,
compare=Compare.AtMost,
value=1,
permissions=[
ContainerPermission.Read,
ContainerPermission.List,
],
),
],
),
TaskType.libfuzzer_regression: TaskDefinition(
features=[
TaskFeature.target_exe,
TaskFeature.target_env,
TaskFeature.target_options,
TaskFeature.target_timeout,
TaskFeature.check_fuzzer_help,
TaskFeature.check_retry_count,
TaskFeature.report_list,
TaskFeature.minimized_stack_depth,
],
vm=VmDefinition(compare=Compare.AtLeast, value=1),
containers=[
ContainerDefinition(
type=ContainerType.setup,
compare=Compare.Equal,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
ContainerDefinition(
type=ContainerType.regression_reports,
compare=Compare.Equal,
value=1,
permissions=[
ContainerPermission.Write,
ContainerPermission.Read,
ContainerPermission.List,
],
),
ContainerDefinition(
type=ContainerType.crashes,
compare=Compare.Equal,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
ContainerDefinition(
type=ContainerType.unique_reports,
compare=Compare.AtMost,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
ContainerDefinition(
type=ContainerType.reports,
compare=Compare.AtMost,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
ContainerDefinition(
type=ContainerType.no_repro,
compare=Compare.AtMost,
value=1,
permissions=[ContainerPermission.Read, ContainerPermission.List],
),
ContainerDefinition(
type=ContainerType.readonly_inputs,
compare=Compare.AtMost,
value=1,
permissions=[
ContainerPermission.Read,
ContainerPermission.List,
],
),
],
),
}
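An editorial sketch of how this table is typically consulted, using only the names already imported at the top of this file (not part of the deleted file).
defn = TASK_DEFINITIONS[TaskType.libfuzzer_fuzz]
assert TaskFeature.target_workers in defn.features
assert defn.monitor_queue is None  # libfuzzer fuzzing does not watch an input queue
setup = [c for c in defn.containers if c.type == ContainerType.setup][0]
assert ContainerPermission.Read in setup.permissions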

View File

@ -1,360 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
from datetime import datetime, timedelta
from typing import List, Optional, Tuple, Union
from uuid import UUID
from onefuzztypes.enums import ErrorCode, TaskState
from onefuzztypes.events import (
EventTaskCreated,
EventTaskFailed,
EventTaskStateUpdated,
EventTaskStopped,
)
from onefuzztypes.models import Error
from onefuzztypes.models import Task as BASE_TASK
from onefuzztypes.models import TaskConfig, TaskVm, UserInfo
from onefuzztypes.primitives import PoolName
from ..azure.image import get_os
from ..azure.queue import create_queue, delete_queue
from ..azure.storage import StorageType
from ..events import send_event
from ..orm import MappingIntStrAny, ORMMixin, QueryFilter
from ..workers.nodes import Node, NodeTasks
from ..workers.pools import Pool
from ..workers.scalesets import Scaleset
class Task(BASE_TASK, ORMMixin):
def check_prereq_tasks(self) -> bool:
if self.config.prereq_tasks:
for task_id in self.config.prereq_tasks:
task = Task.get_by_task_id(task_id)
# if a prereq task fails, then mark this task as failed
if isinstance(task, Error):
self.mark_failed(task)
return False
if task.state not in task.state.has_started():
return False
return True
@classmethod
def create(
cls, config: TaskConfig, job_id: UUID, user_info: UserInfo
) -> Union["Task", Error]:
if config.vm:
os = get_os(config.vm.region, config.vm.image)
if isinstance(os, Error):
return os
elif config.pool:
pool = Pool.get_by_name(config.pool.pool_name)
if isinstance(pool, Error):
return pool
os = pool.os
else:
raise Exception("task must have vm or pool")
task = cls(config=config, job_id=job_id, os=os, user_info=user_info)
task.save()
send_event(
EventTaskCreated(
job_id=task.job_id,
task_id=task.task_id,
config=config,
user_info=user_info,
)
)
logging.info(
"created task. job_id:%s task_id:%s type:%s",
task.job_id,
task.task_id,
task.config.task.type.name,
)
return task
def save_exclude(self) -> Optional[MappingIntStrAny]:
return {"heartbeats": ...}
def is_ready(self) -> bool:
if self.config.prereq_tasks:
for prereq_id in self.config.prereq_tasks:
prereq = Task.get_by_task_id(prereq_id)
if isinstance(prereq, Error):
logging.info("task prereq has error: %s - %s", self.task_id, prereq)
self.mark_failed(prereq)
return False
if prereq.state != TaskState.running:
logging.info(
"task is waiting on prereq: %s - %s:",
self.task_id,
prereq.task_id,
)
return False
return True
# Currently, the telemetry filter will generate something like this:
#
# {
# 'job_id': 'f4a20fd8-0dcc-4a4e-8804-6ef7df50c978',
# 'task_id': '835f7b3f-43ad-4718-b7e4-d506d9667b09',
# 'state': 'stopped',
# 'config': {
# 'task': {'type': 'coverage'},
# 'vm': {'count': 1}
# }
# }
def telemetry_include(self) -> Optional[MappingIntStrAny]:
return {
"job_id": ...,
"task_id": ...,
"state": ...,
"config": {"vm": {"count": ...}, "task": {"type": ...}},
}
def init(self) -> None:
create_queue(self.task_id, StorageType.corpus)
self.set_state(TaskState.waiting)
def stopping(self) -> None:
logging.info("stopping task: %s:%s", self.job_id, self.task_id)
Node.stop_task(self.task_id)
if not NodeTasks.get_nodes_by_task_id(self.task_id):
self.stopped()
def stopped(self) -> None:
self.set_state(TaskState.stopped)
delete_queue(str(self.task_id), StorageType.corpus)
# TODO: we need to 'unschedule' this task from the existing pools
from ..jobs import Job
job = Job.get(self.job_id)
if job:
job.stop_if_all_done()
@classmethod
def search_states(
cls, *, job_id: Optional[UUID] = None, states: Optional[List[TaskState]] = None
) -> List["Task"]:
query: QueryFilter = {}
if job_id:
query["job_id"] = [job_id]
if states:
query["state"] = states
return cls.search(query=query)
@classmethod
def search_expired(cls) -> List["Task"]:
time_filter = "end_time lt datetime'%s'" % datetime.utcnow().isoformat()
return cls.search(
query={"state": TaskState.available()}, raw_unchecked_filter=time_filter
)
@classmethod
def get_by_task_id(cls, task_id: UUID) -> Union[Error, "Task"]:
tasks = cls.search(query={"task_id": [task_id]})
if not tasks:
return Error(code=ErrorCode.INVALID_REQUEST, errors=["unable to find task"])
if len(tasks) != 1:
return Error(
code=ErrorCode.INVALID_REQUEST, errors=["error identifying task"]
)
task = tasks[0]
return task
@classmethod
def get_tasks_by_pool_name(cls, pool_name: PoolName) -> List["Task"]:
tasks = cls.search_states(states=TaskState.available())
if not tasks:
return []
pool_tasks = []
for task in tasks:
task_pool = task.get_pool()
if not task_pool:
continue
if pool_name == task_pool.name:
pool_tasks.append(task)
return pool_tasks
def mark_stopping(self) -> None:
if self.state in TaskState.shutting_down():
logging.debug(
"ignoring post-task stop calls to stop %s:%s", self.job_id, self.task_id
)
return
if self.state not in TaskState.has_started():
self.mark_failed(
Error(code=ErrorCode.TASK_FAILED, errors=["task never started"])
)
self.set_state(TaskState.stopping)
def mark_failed(
self, error: Error, tasks_in_job: Optional[List["Task"]] = None
) -> None:
if self.state in TaskState.shutting_down():
logging.debug(
"ignoring post-task stop failures for %s:%s", self.job_id, self.task_id
)
return
if self.error is not None:
logging.debug(
"ignoring additional task error %s:%s", self.job_id, self.task_id
)
return
logging.error("task failed %s:%s - %s", self.job_id, self.task_id, error)
self.error = error
self.set_state(TaskState.stopping)
self.mark_dependants_failed(tasks_in_job=tasks_in_job)
def mark_dependants_failed(
self, tasks_in_job: Optional[List["Task"]] = None
) -> None:
if tasks_in_job is None:
tasks_in_job = Task.search(query={"job_id": [self.job_id]})
for task in tasks_in_job:
if task.config.prereq_tasks:
if self.task_id in task.config.prereq_tasks:
task.mark_failed(
Error(
code=ErrorCode.TASK_FAILED,
errors=[
"prerequisite task failed. task_id:%s" % self.task_id
],
),
tasks_in_job,
)
def get_pool(self) -> Optional[Pool]:
if self.config.pool:
pool = Pool.get_by_name(self.config.pool.pool_name)
if isinstance(pool, Error):
logging.info(
"unable to schedule task to pool: %s - %s", self.task_id, pool
)
return None
return pool
elif self.config.vm:
scalesets = Scaleset.search()
scalesets = [
x
for x in scalesets
if x.vm_sku == self.config.vm.sku and x.image == self.config.vm.image
]
for scaleset in scalesets:
pool = Pool.get_by_name(scaleset.pool_name)
if isinstance(pool, Error):
logging.info(
"unable to schedule task to pool: %s - %s",
self.task_id,
pool,
)
else:
return pool
logging.warning(
"unable to find a scaleset that matches the task prereqs: %s",
self.task_id,
)
return None
def get_repro_vm_config(self) -> Union[TaskVm, None]:
if self.config.vm:
return self.config.vm
if self.config.pool is None:
raise Exception("either pool or vm must be specified: %s" % self.task_id)
pool = Pool.get_by_name(self.config.pool.pool_name)
if isinstance(pool, Error):
logging.info("unable to find pool from task: %s", self.task_id)
return None
for scaleset in Scaleset.search_by_pool(self.config.pool.pool_name):
return TaskVm(
region=scaleset.region,
sku=scaleset.vm_sku,
image=scaleset.image,
)
logging.warning(
"no scalesets are defined for task: %s:%s", self.job_id, self.task_id
)
return None
def on_start(self) -> None:
# try to keep this effectively idempotent
if self.end_time is None:
self.end_time = datetime.utcnow() + timedelta(
hours=self.config.task.duration
)
from ..jobs import Job
job = Job.get(self.job_id)
if job:
job.on_start()
@classmethod
def key_fields(cls) -> Tuple[str, str]:
return ("job_id", "task_id")
def set_state(self, state: TaskState) -> None:
if self.state == state:
return
self.state = state
if self.state in [TaskState.running, TaskState.setting_up]:
self.on_start()
self.save()
if self.state == TaskState.stopped:
if self.error:
send_event(
EventTaskFailed(
job_id=self.job_id,
task_id=self.task_id,
error=self.error,
user_info=self.user_info,
config=self.config,
)
)
else:
send_event(
EventTaskStopped(
job_id=self.job_id,
task_id=self.task_id,
user_info=self.user_info,
config=self.config,
)
)
else:
send_event(
EventTaskStateUpdated(
job_id=self.job_id,
task_id=self.task_id,
state=self.state,
end_time=self.end_time,
config=self.config,
)
)

View File

@ -1,248 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
from typing import Dict, Generator, List, Optional, Tuple, TypeVar
from uuid import UUID, uuid4
from onefuzztypes.enums import OS, PoolState, TaskState
from onefuzztypes.models import WorkSet, WorkUnit
from onefuzztypes.primitives import Container
from pydantic import BaseModel
from ..azure.containers import blob_exists, get_container_sas_url
from ..azure.storage import StorageType
from ..jobs import Job
from ..workers.pools import Pool
from .config import build_task_config, get_setup_container
from .main import Task
HOURS = 60 * 60
# TODO: eventually, this should be tied to the pool.
MAX_TASKS_PER_SET = 10
A = TypeVar("A")
def chunks(items: List[A], size: int) -> Generator[List[A], None, None]:
return (items[x : x + size] for x in range(0, len(items), size))
def schedule_workset(workset: WorkSet, pool: Pool, count: int) -> bool:
if pool.state not in PoolState.available():
logging.info(
"pool not available for work: %s state: %s", pool.name, pool.state.name
)
return False
for _ in range(count):
if not pool.schedule_workset(workset):
logging.error(
"unable to schedule workset. pool:%s workset:%s", pool.name, workset
)
return False
return True
# TODO - Once Pydantic supports hashable models, the Tuple should be replaced
# with a model.
#
# For info: https://github.com/samuelcolvin/pydantic/pull/1881
def bucket_tasks(tasks: List[Task]) -> Dict[Tuple, List[Task]]:
# buckets are hashed by:
# OS, JOB ID, vm sku & image (if available), pool name (if available),
# if the setup script requires rebooting, and a 'unique' value
#
# The unique value is set based on the following conditions:
# * if the task is set to run on more than one VM, then we assume it can't be shared
# * if the task is missing the 'colocate' flag or it's set to False
buckets: Dict[Tuple, List[Task]] = {}
for task in tasks:
vm: Optional[Tuple[str, str]] = None
pool: Optional[str] = None
unique: Optional[UUID] = None
# check for multiple VMs for pre-1.0.0 tasks
if task.config.vm:
vm = (task.config.vm.sku, task.config.vm.image)
if task.config.vm.count > 1:
unique = uuid4()
# check for multiple VMs for 1.0.0 and later tasks
if task.config.pool:
pool = task.config.pool.pool_name
if task.config.pool.count > 1:
unique = uuid4()
if not task.config.colocate:
unique = uuid4()
key = (
task.os,
task.job_id,
vm,
pool,
get_setup_container(task.config),
task.config.task.reboot_after_setup,
unique,
)
if key not in buckets:
buckets[key] = []
buckets[key].append(task)
return buckets
class BucketConfig(BaseModel):
count: int
reboot: bool
setup_container: Container
setup_script: Optional[str]
pool: Pool
def build_work_unit(task: Task) -> Optional[Tuple[BucketConfig, WorkUnit]]:
pool = task.get_pool()
if not pool:
logging.info("unable to find pool for task: %s", task.task_id)
return None
logging.info("scheduling task: %s", task.task_id)
job = Job.get(task.job_id)
if not job:
raise Exception(f"invalid job_id {task.job_id} for task {task.task_id}")
task_config = build_task_config(job, task)
setup_container = get_setup_container(task.config)
setup_script = None
if task.os == OS.windows and blob_exists(
setup_container, "setup.ps1", StorageType.corpus
):
setup_script = "setup.ps1"
if task.os == OS.linux and blob_exists(
setup_container, "setup.sh", StorageType.corpus
):
setup_script = "setup.sh"
reboot = False
count = 1
if task.config.pool:
count = task.config.pool.count
# NOTE: "is True" is required to handle Optional[bool]
reboot = task.config.task.reboot_after_setup is True
elif task.config.vm:
# this branch should go away when we stop letting people specify
# VM configs directly.
count = task.config.vm.count
# NOTE: "is True" is required to handle Optional[bool]
reboot = (
task.config.vm.reboot_after_setup is True
or task.config.task.reboot_after_setup is True
)
else:
raise TypeError
work_unit = WorkUnit(
job_id=task_config.job_id,
task_id=task_config.task_id,
task_type=task_config.task_type,
config=task_config.json(exclude_none=True, exclude_unset=True),
)
bucket_config = BucketConfig(
pool=pool,
count=count,
reboot=reboot,
setup_script=setup_script,
setup_container=setup_container,
)
return bucket_config, work_unit
def build_work_set(tasks: List[Task]) -> Optional[Tuple[BucketConfig, WorkSet]]:
task_ids = [x.task_id for x in tasks]
bucket_config: Optional[BucketConfig] = None
work_units = []
for task in tasks:
if task.config.prereq_tasks:
# if all of the prereqs are in this bucket, they will be
# scheduled together
if not all([task_id in task_ids for task_id in task.config.prereq_tasks]):
if not task.check_prereq_tasks():
continue
result = build_work_unit(task)
if not result:
continue
new_bucket_config, work_unit = result
if bucket_config is None:
bucket_config = new_bucket_config
else:
if bucket_config != new_bucket_config:
raise Exception(
f"bucket configs differ: {bucket_config} VS {new_bucket_config}"
)
work_units.append(work_unit)
if bucket_config:
setup_url = get_container_sas_url(
bucket_config.setup_container, StorageType.corpus, read=True, list_=True
)
work_set = WorkSet(
reboot=bucket_config.reboot,
script=(bucket_config.setup_script is not None),
setup_url=setup_url,
work_units=work_units,
)
return (bucket_config, work_set)
return None
def schedule_tasks() -> None:
tasks: List[Task] = []
tasks = Task.search_states(states=[TaskState.waiting])
tasks_by_id = {x.task_id: x for x in tasks}
seen = set()
not_ready_count = 0
buckets = bucket_tasks(tasks)
for bucketed_tasks in buckets.values():
for chunk in chunks(bucketed_tasks, MAX_TASKS_PER_SET):
result = build_work_set(chunk)
if result is None:
continue
bucket_config, work_set = result
if schedule_workset(work_set, bucket_config.pool, bucket_config.count):
for work_unit in work_set.work_units:
task = tasks_by_id[work_unit.task_id]
task.set_state(TaskState.scheduled)
seen.add(task.task_id)
not_ready_count = len(tasks) - len(seen)
if not_ready_count > 0:
logging.info("tasks not ready: %d", not_ready_count)

View File

@ -1,72 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import os
from typing import Any, Dict, Optional, Union
from onefuzztypes.enums import TelemetryData, TelemetryEvent
from opencensus.ext.azure.log_exporter import AzureLogHandler
LOCAL_CLIENT: Optional[logging.Logger] = None
CENTRAL_CLIENT: Optional[logging.Logger] = None
def _get_client(environ_key: str) -> Optional[logging.Logger]:
key = os.environ.get(environ_key)
if key is None:
return None
client = logging.getLogger("onefuzz")
client.addHandler(AzureLogHandler(connection_string="InstrumentationKey=%s" % key))
return client
def _central_client() -> Optional[logging.Logger]:
global CENTRAL_CLIENT
if not CENTRAL_CLIENT:
CENTRAL_CLIENT = _get_client("ONEFUZZ_TELEMETRY")
return CENTRAL_CLIENT
def _local_client() -> Optional[logging.Logger]:
global LOCAL_CLIENT
if not LOCAL_CLIENT:
LOCAL_CLIENT = _get_client("APPINSIGHTS_INSTRUMENTATIONKEY")
return LOCAL_CLIENT
# NOTE: All telemetry that is *NOT* using the ORM telemetry_include should
# go through this method
#
# This provides a point of inspection to know if it's data that is safe to
# log to the central OneFuzz telemetry point
def track_event(
event: TelemetryEvent, data: Dict[TelemetryData, Union[str, int]]
) -> None:
central = _central_client()
local = _local_client()
if local:
serialized = {k.name: v for (k, v) in data.items()}
local.info(event.name, extra={"custom_dimensions": serialized})
if event in TelemetryEvent.can_share() and central:
serialized = {
k.name: v for (k, v) in data.items() if k in TelemetryData.can_share()
}
central.info(event.name, extra={"custom_dimensions": serialized})
# NOTE: This should *only* be used for logging Telemetry data that uses
# the ORM telemetry_include method to limit data for telemetry.
def track_event_filtered(event: TelemetryEvent, data: Any) -> None:
central = _central_client()
local = _local_client()
if local:
local.info(event.name, extra={"custom_dimensions": data})
if central and event in TelemetryEvent.can_share():
central.info(event.name, extra={"custom_dimensions": data})
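
To make the sharing filter above concrete, here is a small self-contained sketch with stand-in enums (the real TelemetryEvent/TelemetryData enums live in onefuzztypes and their members are not shown in this diff): only fields returned by can_share() ever reach the central telemetry client.

from enum import Enum


class DemoData(Enum):
    # stand-in for onefuzztypes.enums.TelemetryData, for illustration only
    task_id = "task_id"
    user_upn = "user_upn"

    @classmethod
    def can_share(cls) -> set:
        return {cls.task_id}


raw = {DemoData.task_id: "abc123", DemoData.user_upn: "user@contoso.com"}
shared = {k.name: v for (k, v) in raw.items() if k in DemoData.can_share()}
assert shared == {"task_id": "abc123"}  # user_upn never leaves the instance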

View File

@ -1,114 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
from typing import Dict, Optional, Type
from azure.core.exceptions import ResourceNotFoundError
from msrestazure.azure_exceptions import CloudError
from onefuzztypes.enums import UpdateType
from pydantic import BaseModel
from .azure.queue import queue_object
from .azure.storage import StorageType
# This class isn't intended to be shared outside of the service
class Update(BaseModel):
update_type: UpdateType
PartitionKey: Optional[str]
RowKey: Optional[str]
method: Optional[str]
def queue_update(
update_type: UpdateType,
PartitionKey: Optional[str] = None,
RowKey: Optional[str] = None,
method: Optional[str] = None,
    visibility_timeout: Optional[int] = None,
) -> None:
logging.info(
"queuing type:%s id:[%s,%s] method:%s timeout: %s",
update_type.name,
PartitionKey,
RowKey,
method,
visibility_timeout,
)
update = Update(
update_type=update_type, PartitionKey=PartitionKey, RowKey=RowKey, method=method
)
try:
if not queue_object(
"update-queue",
update,
StorageType.config,
visibility_timeout=visibility_timeout,
):
logging.error("unable to queue update")
except (CloudError, ResourceNotFoundError) as err:
logging.error("GOT ERROR %s", repr(err))
logging.error("GOT ERROR %s", dir(err))
raise err
def execute_update(update: Update) -> None:
from .jobs import Job
from .orm import ORMMixin
from .proxy import Proxy
from .repro import Repro
from .tasks.main import Task
from .workers.nodes import Node
from .workers.pools import Pool
from .workers.scalesets import Scaleset
update_objects: Dict[UpdateType, Type[ORMMixin]] = {
UpdateType.Task: Task,
UpdateType.Job: Job,
UpdateType.Repro: Repro,
UpdateType.Proxy: Proxy,
UpdateType.Pool: Pool,
UpdateType.Node: Node,
UpdateType.Scaleset: Scaleset,
}
# TODO: remove these from being queued, these updates are handled elsewhere
if update.update_type == UpdateType.Scaleset:
return
if update.update_type in update_objects:
if update.PartitionKey is None or update.RowKey is None:
raise Exception("unsupported update: %s" % update)
obj = update_objects[update.update_type].get(update.PartitionKey, update.RowKey)
if not obj:
logging.error("unable find to obj to update %s", update)
return
if update.method and hasattr(obj, update.method):
logging.info("performing queued update: %s", update)
getattr(obj, update.method)()
return
else:
state = getattr(obj, "state", None)
if state is None:
logging.error("queued update for object without state: %s", update)
return
func = getattr(obj, state.name, None)
if func is None:
logging.debug(
"no function to implement state: %s - %s", update, state.name
)
return
logging.info(
"performing queued update for state: %s - %s", update, state.name
)
func()
return
raise NotImplementedError("unimplemented update type: %s" % update.update_type.name)
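
As a hedged usage sketch, a caller wanting to retry a task's state handler after roughly a minute might queue an update like the following; the partition and row key values are placeholders, not real records:

from onefuzztypes.enums import UpdateType

# placeholder IDs for illustration; real callers pass the task's job_id/task_id
queue_update(
    UpdateType.Task,
    PartitionKey="00000000-0000-0000-0000-000000000000",  # job_id
    RowKey="11111111-1111-1111-1111-111111111111",  # task_id
    visibility_timeout=60,
)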

View File

@ -1,79 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
from typing import List, Optional
from uuid import UUID
import azure.functions as func
import jwt
from memoization import cached
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error, Result, UserInfo
from .config import InstanceConfig
def get_bearer_token(request: func.HttpRequest) -> Optional[str]:
    auth: Optional[str] = request.headers.get("Authorization", None)
if not auth:
return None
parts = auth.split()
if len(parts) != 2:
return None
if parts[0].lower() != "bearer":
return None
return parts[1]
def get_auth_token(request: func.HttpRequest) -> Optional[str]:
token = get_bearer_token(request)
if token is not None:
return token
token_header = request.headers.get("x-ms-token-aad-id-token", None)
if token_header is None:
return None
return str(token_header)
@cached(ttl=60)
def get_allowed_tenants() -> List[str]:
config = InstanceConfig.fetch()
entries = [f"https://sts.windows.net/{x}/" for x in config.allowed_aad_tenants]
return entries
def parse_jwt_token(request: func.HttpRequest) -> Result[UserInfo]:
"""Obtains the Access Token from the Authorization Header"""
token_str = get_auth_token(request)
if token_str is None:
return Error(
code=ErrorCode.INVALID_REQUEST,
errors=["unable to find authorization token"],
)
# The JWT token has already been verified by the azure authentication layer,
# but we need to verify the tenant is as we expect.
token = jwt.decode(token_str, options={"verify_signature": False})
if "iss" not in token:
return Error(
code=ErrorCode.INVALID_REQUEST, errors=["missing issuer from token"]
)
tenants = get_allowed_tenants()
if token["iss"] not in tenants:
logging.error("issuer not from allowed tenant: %s - %s", token["iss"], tenants)
return Error(code=ErrorCode.INVALID_REQUEST, errors=["unauthorized AAD issuer"])
application_id = UUID(token["appid"]) if "appid" in token else None
object_id = UUID(token["oid"]) if "oid" in token else None
upn = token.get("upn")
return UserInfo(application_id=application_id, object_id=object_id, upn=upn)
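
For context, a hedged sketch of how an HTTP-triggered function might have consumed parse_jwt_token; the handler name and responses are illustrative, not part of this diff:

import azure.functions as func
from onefuzztypes.models import Error


def main(req: func.HttpRequest) -> func.HttpResponse:
    user = parse_jwt_token(req)
    if isinstance(user, Error):
        # token missing, issuer not in the allowed tenant list, etc.
        return func.HttpResponse("unauthorized", status_code=401)
    return func.HttpResponse(f"hello {user.upn}", status_code=200)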

Some files were not shown because too many files have changed in this diff.