refactor agent_events handler (#261)

This commit is contained in:
bmc-msft
2020-11-11 18:28:16 -05:00
committed by GitHub
parent 382003e4d0
commit ca209eb543
4 changed files with 294 additions and 216 deletions

View File

@ -4,213 +4,56 @@
# Licensed under the MIT License.
import logging
from typing import Optional, cast
from uuid import UUID
import azure.functions as func
from onefuzztypes.enums import (
ErrorCode,
NodeState,
NodeTaskState,
TaskDebugFlag,
TaskState,
)
from onefuzztypes.models import (
Error,
NodeDoneEventData,
NodeEvent,
NodeEventEnvelope,
NodeSettingUpEventData,
NodeStateUpdate,
Result,
WorkerEvent,
)
from onefuzztypes.responses import BoolResult
from ..onefuzzlib.agent_authorization import verify_token
from ..onefuzzlib.pools import Node, NodeTasks
from ..onefuzzlib.request import RequestException, not_ok, ok, parse_request
from ..onefuzzlib.task_event import TaskEvent
from ..onefuzzlib.tasks.main import Task
ERROR_CONTEXT = "node event"
from ..onefuzzlib.agent_events import on_state_update, on_worker_event
from ..onefuzzlib.request import not_ok, ok, parse_request
def get_task_checked(task_id: UUID) -> Task:
task = Task.get_by_task_id(task_id)
if isinstance(task, Error):
raise RequestException(task)
return task
def get_node_checked(machine_id: UUID) -> Node:
node = Node.get_by_machine_id(machine_id)
if not node:
err = Error(code=ErrorCode.INVALID_NODE, errors=["unable to find node"])
raise RequestException(err)
return node
def on_state_update(
machine_id: UUID,
state_update: NodeStateUpdate,
) -> None:
state = state_update.state
node = get_node_checked(machine_id)
if state == NodeState.free:
if node.reimage_requested or node.delete_requested:
logging.info("stopping free node with reset flags: %s", node.machine_id)
node.stop()
return
if node.could_shrink_scaleset():
logging.info("stopping free node to resize scaleset: %s", node.machine_id)
node.set_halt()
return
if state == NodeState.init:
if node.delete_requested:
node.stop()
return
node.reimage_requested = False
node.save()
elif node.state not in NodeState.ready_for_reset():
if node.state != state:
node.state = state
node.save()
if state == NodeState.setting_up:
# Model-validated.
#
# This field will be required in the future.
# For now, it is optional for back compat.
setting_up_data = cast(
Optional[NodeSettingUpEventData],
state_update.data,
)
if setting_up_data:
for task_id in setting_up_data.tasks:
task = get_task_checked(task_id)
# The task state may be `running` if it has `vm_count` > 1, and
# another node is concurrently executing the task. If so, leave
# the state as-is, to represent the max progress made.
#
# Other states we would want to preserve are excluded by the
# outermost conditional check.
if task.state != TaskState.running:
task.state = TaskState.setting_up
task.on_start()
task.save()
# Note: we set the node task state to `setting_up`, even though
# the task itself may be `running`.
node_task = NodeTasks(
machine_id=machine_id,
task_id=task_id,
state=NodeTaskState.setting_up,
)
node_task.save()
elif state == NodeState.done:
# if tasks are running on the node when it reports as Done
# those are stopped early
node.mark_tasks_stopped_early()
# Model-validated.
#
# This field will be required in the future.
# For now, it is optional for back compat.
done_data = cast(Optional[NodeDoneEventData], state_update.data)
if done_data:
# TODO: do something with this done data
if done_data.error:
logging.error(
"node 'done' with error: machine_id:%s, data:%s",
machine_id,
done_data,
)
else:
logging.debug("No change in Node state")
else:
logging.info("ignoring state updates from the node: %s: %s", machine_id, state)
def on_worker_event(machine_id: UUID, event: WorkerEvent) -> None:
if event.running:
task_id = event.running.task_id
elif event.done:
task_id = event.done.task_id
else:
raise NotImplementedError
task = get_task_checked(task_id)
node = get_node_checked(machine_id)
node_task = NodeTasks(
machine_id=machine_id, task_id=task_id, state=NodeTaskState.running
def process(envelope: NodeEventEnvelope) -> Result[None]:
logging.info(
"node event: machine_id: %s event: %s",
envelope.machine_id,
envelope.event,
)
if event.running:
if task.state not in TaskState.shutting_down():
task.state = TaskState.running
if node.state not in NodeState.ready_for_reset():
node.state = NodeState.busy
node.save()
node_task.save()
if isinstance(envelope.event, NodeStateUpdate):
return on_state_update(envelope.machine_id, envelope.event)
# Start the clock for the task if it wasn't started already
# (as happens in 1.0.0 agents)
task.on_start()
elif event.done:
exit_status = event.done.exit_status
if not exit_status.success:
logging.error("task failed. status:%s", exit_status)
task.mark_failed(
Error(
code=ErrorCode.TASK_FAILED,
errors=[
"task failed. exit_status:%s" % exit_status,
event.done.stdout[-4096:],
event.done.stderr[-4096:],
],
)
)
if task.config.debug and (
TaskDebugFlag.keep_node_on_failure in task.config.debug
or TaskDebugFlag.keep_node_on_completion in task.config.debug
):
node.debug_keep_node = True
node.save()
if isinstance(envelope.event, WorkerEvent):
return on_worker_event(envelope.machine_id, envelope.event)
else:
task.mark_stopping()
if (
task.config.debug
and TaskDebugFlag.keep_node_on_completion in task.config.debug
):
node.debug_keep_node = True
node.save()
if isinstance(envelope.event, NodeEvent):
if envelope.event.state_update:
result = on_state_update(envelope.machine_id, envelope.event.state_update)
if result is not None:
return result
node.to_reimage(done=True)
else:
err = Error(
code=ErrorCode.INVALID_REQUEST,
errors=["invalid worker event type"],
)
raise RequestException(err)
if envelope.event.worker_event:
result = on_worker_event(envelope.machine_id, envelope.event.worker_event)
if result is not None:
return result
task.save()
return None
task_event = TaskEvent(task_id=task_id, machine_id=machine_id, event_data=event)
task_event.save()
raise NotImplementedError("invalid node event: %s" % envelope)
def post(req: func.HttpRequest) -> func.HttpResponse:
envelope = parse_request(NodeEventEnvelope, req)
if isinstance(envelope, Error):
return not_ok(envelope, context=ERROR_CONTEXT)
return not_ok(envelope, context="node event")
logging.info(
"node event: machine_id: %s event: %s",
@ -218,34 +61,15 @@ def post(req: func.HttpRequest) -> func.HttpResponse:
envelope.event.json(exclude_none=True),
)
if isinstance(envelope.event, NodeEvent):
event = envelope.event
elif isinstance(envelope.event, NodeStateUpdate):
event = NodeEvent(state_update=envelope.event)
elif isinstance(envelope.event, WorkerEvent):
event = NodeEvent(worker_event=envelope.event)
else:
err = Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid node event"])
return not_ok(err, context=ERROR_CONTEXT)
result = process(envelope)
if isinstance(result, Error):
logging.error(
"unable to process agent event. envelope:%s error:%s", envelope, result
)
return not_ok(result, context="node event")
if event.state_update:
on_state_update(envelope.machine_id, event.state_update)
return ok(BoolResult(result=True))
elif event.worker_event:
on_worker_event(envelope.machine_id, event.worker_event)
return ok(BoolResult(result=True))
else:
err = Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid node event"])
return not_ok(err, context=ERROR_CONTEXT)
return ok(BoolResult(result=True))
def main(req: func.HttpRequest) -> func.HttpResponse:
try:
if req.method == "POST":
m = post
else:
raise Exception("invalid method")
return verify_token(req, m)
except RequestException as r:
return not_ok(r.error, context=ERROR_CONTEXT)
return verify_token(req, post)

View File

@ -7,16 +7,14 @@
"direction": "in",
"name": "req",
"methods": [
"get",
"post",
"delete"
"post"
],
"route": "agents/events"
},
{
},
{
"type": "http",
"direction": "out",
"name": "$return"
}
}
]
}
}

View File

@ -0,0 +1,252 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
from typing import Optional, cast
from uuid import UUID
from onefuzztypes.enums import (
ErrorCode,
NodeState,
NodeTaskState,
TaskDebugFlag,
TaskState,
)
from onefuzztypes.models import (
Error,
NodeDoneEventData,
NodeSettingUpEventData,
NodeStateUpdate,
Result,
WorkerDoneEvent,
WorkerEvent,
WorkerRunningEvent,
)
from ..onefuzzlib.pools import Node, NodeTasks
from ..onefuzzlib.task_event import TaskEvent
from ..onefuzzlib.tasks.main import Task
def get_node(machine_id: UUID) -> Result[Node]:
node = Node.get_by_machine_id(machine_id)
if not node:
return Error(code=ErrorCode.INVALID_NODE, errors=["unable to find node"])
return node
def on_state_update(
machine_id: UUID,
state_update: NodeStateUpdate,
) -> Result[None]:
state = state_update.state
node = get_node(machine_id)
if isinstance(node, Error):
return node
if state == NodeState.free:
if node.reimage_requested or node.delete_requested:
logging.info("stopping free node with reset flags: %s", node.machine_id)
node.stop()
return None
if node.could_shrink_scaleset():
logging.info("stopping free node to resize scaleset: %s", node.machine_id)
node.set_halt()
return None
if state == NodeState.init:
if node.delete_requested:
logging.info("stopping node (init and delete_requested): %s", machine_id)
node.stop()
return None
# not checking reimage_requested, as nodes only send 'init' state once. If
# they send 'init' with reimage_requested, it's because the node was reimaged
# successfully.
node.reimage_requested = False
node.state = state
node.save()
return None
logging.info("node state update: %s from:%s to:%s", machine_id, node.state, state)
node.state = state
node.save()
if state == NodeState.free:
logging.info("node now available for work: %s", machine_id)
elif state == NodeState.setting_up:
# Model-validated.
#
# This field will be required in the future.
# For now, it is optional for back compat.
setting_up_data = cast(
Optional[NodeSettingUpEventData],
state_update.data,
)
if setting_up_data:
if not setting_up_data.tasks:
return Error(
code=ErrorCode.INVALID_REQUEST,
errors=["setup without tasks. machine_id: %s", str(machine_id)],
)
for task_id in setting_up_data.tasks:
task = Task.get_by_task_id(task_id)
if isinstance(task, Error):
return task
logging.info(
"node starting task. machine_id: %s job_id: %s task_id: %s",
machine_id,
task.job_id,
task.task_id,
)
# The task state may be `running` if it has `vm_count` > 1, and
# another node is concurrently executing the task. If so, leave
# the state as-is, to represent the max progress made.
#
# Other states we would want to preserve are excluded by the
# outermost conditional check.
if task.state not in [TaskState.running, TaskState.setting_up]:
task.state = TaskState.setting_up
task.save()
task.on_start()
# Note: we set the node task state to `setting_up`, even though
# the task itself may be `running`.
node_task = NodeTasks(
machine_id=machine_id,
task_id=task_id,
state=NodeTaskState.setting_up,
)
node_task.save()
elif state == NodeState.done:
# if tasks are running on the node when it reports as Done
# those are stopped early
node.mark_tasks_stopped_early()
# Model-validated.
#
# This field will be required in the future.
# For now, it is optional for back compat.
done_data = cast(Optional[NodeDoneEventData], state_update.data)
if done_data:
# TODO: do something with this done data
if done_data.error:
logging.error(
"node 'done' with error: machine_id:%s, data:%s",
machine_id,
done_data,
)
return None
def on_worker_event_running(
machine_id: UUID, event: WorkerRunningEvent
) -> Result[None]:
task = Task.get_by_task_id(event.task_id)
if isinstance(task, Error):
return task
node = get_node(machine_id)
if isinstance(node, Error):
return node
if node.state not in NodeState.ready_for_reset():
node.state = NodeState.busy
node.save()
node_task = NodeTasks(
machine_id=machine_id, task_id=event.task_id, state=NodeTaskState.running
)
node_task.save()
if task.state in TaskState.shutting_down():
logging.info(
"ignoring task start from node. machine_id:%s %s:%s (state: %s)",
machine_id,
task.job_id,
task.task_id,
task.state,
)
return None
logging.info(
"task started on node. machine_id:%s %s:%s",
machine_id,
task.job_id,
task.task_id,
)
task.state = TaskState.running
task.save()
# Start the clock for the task if it wasn't started already
# (as happens in 1.0.0 agents)
task.on_start()
return None
def on_worker_event_done(machine_id: UUID, event: WorkerDoneEvent) -> Result[None]:
task = Task.get_by_task_id(event.task_id)
if isinstance(task, Error):
return task
node = get_node(machine_id)
if isinstance(node, Error):
return node
if event.exit_status.success:
logging.info(
"task done. %s:%s status:%s", task.job_id, task.task_id, event.exit_status
)
task.mark_stopping()
if (
task.config.debug
and TaskDebugFlag.keep_node_on_completion in task.config.debug
):
node.debug_keep_node = True
node.save()
else:
logging.error(
"task failed. %s:%s status:%s", task.job_id, task.task_id, event.exit_status
)
task.mark_failed(
Error(
code=ErrorCode.TASK_FAILED,
errors=[
"task failed. exit_status:%s" % event.exit_status,
event.stdout[-4096:],
event.stderr[-4096:],
],
)
)
if task.config.debug and (
TaskDebugFlag.keep_node_on_failure in task.config.debug
or TaskDebugFlag.keep_node_on_completion in task.config.debug
):
node.debug_keep_node = True
node.save()
node.to_reimage(done=True)
task_event = TaskEvent(
task_id=task.task_id, machine_id=machine_id, event_data=WorkerEvent(done=event)
)
task_event.save()
return None
def on_worker_event(machine_id: UUID, event: WorkerEvent) -> Result[None]:
if event.running:
return on_worker_event_running(machine_id, event.running)
elif event.done:
return on_worker_event_done(machine_id, event.done)
else:
raise NotImplementedError

View File

@ -4,7 +4,7 @@
# Licensed under the MIT License.
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
from uuid import UUID, uuid4
from pydantic import BaseModel, Field, root_validator, validator
@ -58,6 +58,10 @@ class Error(BaseModel):
errors: List[str]
OkType = TypeVar("OkType")
Result = Union[OkType, Error]
class FileEntry(BaseModel):
container: Container
filename: str