diff --git a/src/agent/onefuzz-supervisor/src/coordinator.rs b/src/agent/onefuzz-supervisor/src/coordinator.rs index 34ebc6c76..f474d3f75 100644 --- a/src/agent/onefuzz-supervisor/src/coordinator.rs +++ b/src/agent/onefuzz-supervisor/src/coordinator.rs @@ -18,10 +18,10 @@ pub struct StopTask { } #[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)] -#[serde(tag = "command_type")] +#[serde(rename_all = "snake_case")] pub enum NodeCommand { - #[serde(alias = "stop")] StopTask(StopTask), + Stop {}, } #[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)] @@ -35,6 +35,17 @@ pub struct PendingNodeCommand { envelope: Option, } +#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)] +pub struct PollCommandsRequest { + machine_id: Uuid, +} + +#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)] +pub struct ClaimNodeCommandRequest { + machine_id: Uuid, + message_id: String, +} + #[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)] #[serde(rename_all = "snake_case")] pub enum NodeState { @@ -150,7 +161,9 @@ impl Coordinator { let pending: PendingNodeCommand = serde_json::from_slice(&data)?; if let Some(envelope) = pending.envelope { - // TODO: DELETE dequeued command via `message_id`. + let request = RequestType::ClaimCommand(envelope.message_id); + self.send_with_auth_retry(request).await?; + Ok(Some(envelope.command)) } else { Ok(None) @@ -188,7 +201,7 @@ impl Coordinator { &mut self, request_type: RequestType<'a>, ) -> Result { - let request = self.build_request(request_type)?; + let request = self.build_request(request_type.clone())?; let mut response = self.client.execute(request).await?; if response.status() == StatusCode::UNAUTHORIZED { @@ -214,17 +227,40 @@ impl Coordinator { fn build_request<'a>(&self, request_type: RequestType<'a>) -> Result { match request_type { RequestType::PollCommands => self.poll_commands_request(), + RequestType::ClaimCommand(message_id) => self.claim_command_request(message_id), RequestType::EmitEvent(event) => self.emit_event_request(event), RequestType::CanSchedule(work_set) => self.can_schedule_request(work_set), } } fn poll_commands_request(&self) -> Result { + let request = PollCommandsRequest { + machine_id: self.registration.machine_id, + }; + let url = self.registration.dynamic_config.commands_url.clone(); let request = self .client .get(url) .bearer_auth(self.token.secret().expose()) + .json(&request) + .build()?; + + Ok(request) + } + + fn claim_command_request(&self, message_id: String) -> Result { + let request = ClaimNodeCommandRequest { + machine_id: self.registration.machine_id, + message_id, + }; + + let url = self.registration.dynamic_config.commands_url.clone(); + let request = self + .client + .delete(url) + .bearer_auth(self.token.secret().expose()) + .json(&request) .build()?; Ok(request) @@ -271,9 +307,10 @@ impl Coordinator { // The upstream `Request` type is not `Clone`, so we can't retry a request // without rebuilding it. We use this enum to dispatch to a private method, // avoiding borrowck conflicts that occur when capturing `self`. -#[derive(Copy, Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Debug, Eq, PartialEq)] enum RequestType<'a> { PollCommands, + ClaimCommand(String), EmitEvent(&'a NodeEventEnvelope), CanSchedule(&'a WorkSet), } diff --git a/src/agent/onefuzz-supervisor/src/scheduler.rs b/src/agent/onefuzz-supervisor/src/scheduler.rs index 5f11efeef..4906d7f58 100644 --- a/src/agent/onefuzz-supervisor/src/scheduler.rs +++ b/src/agent/onefuzz-supervisor/src/scheduler.rs @@ -46,6 +46,10 @@ impl Scheduler { state.stop(stop_task.task_id)?; } } + NodeCommand::Stop {} => { + let state = State { ctx: Done {} }; + *self = state.into(); + } } Ok(()) diff --git a/src/agent/onefuzz/src/input_tester.rs b/src/agent/onefuzz/src/input_tester.rs index 9c8db34f4..2353281c6 100644 --- a/src/agent/onefuzz/src/input_tester.rs +++ b/src/agent/onefuzz/src/input_tester.rs @@ -84,7 +84,7 @@ impl<'a> Tester<'a> { .map(|f| f.to_string()) .collect(); - let crash_site = if let Some(frame) = call_stack.iter().next() { + let crash_site = if let Some(frame) = call_stack.get(0) { frame.to_string() } else { CRASH_SITE_UNAVAILABLE.to_owned() diff --git a/src/pytypes/onefuzztypes/models.py b/src/pytypes/onefuzztypes/models.py index 3fdbe5894..145f09ad1 100644 --- a/src/pytypes/onefuzztypes/models.py +++ b/src/pytypes/onefuzztypes/models.py @@ -4,10 +4,10 @@ # Licensed under the MIT License. from datetime import datetime -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from uuid import UUID, uuid4 -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, Field, root_validator, validator from .consts import ONE_HOUR, SEVEN_DAYS from .enums import ( @@ -32,6 +32,24 @@ from .enums import ( from .primitives import Container, PoolName, Region +class EnumModel(BaseModel): + @root_validator(pre=True) + def exactly_one(cls: Any, values: Any) -> Any: + some = [] + + for field, val in values.items(): + if val is not None: + some.append(field) + + if not some: + raise ValueError('no variant set for enum') + + if len(some) > 1: + raise ValueError('multiple values set for enum: %s' % some) + + return values + + class Error(BaseModel): code: ErrorCode errors: List[str] @@ -492,11 +510,17 @@ class NodeEventEnvelope(BaseModel): event: NodeEvent -class NodeCommandStopTask(BaseModel): +class StopNodeCommand(BaseModel): + pass + + +class StopTaskNodeCommand(BaseModel): task_id: UUID -NodeCommand = Union[NodeCommandStopTask] +class NodeCommand(EnumModel): + stop: Optional[StopNodeCommand] + stop_task: Optional[StopTaskNodeCommand] class NodeCommandEnvelope(BaseModel):