Add stop node command (#32)

This commit is contained in:
Joe Ranweiler 2020-09-28 07:34:53 -07:00 committed by GitHub
parent ab41b8986b
commit 1ab55d942f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 75 additions and 10 deletions

View File

@ -18,10 +18,10 @@ pub struct StopTask {
} }
#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)] #[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(tag = "command_type")] #[serde(rename_all = "snake_case")]
pub enum NodeCommand { pub enum NodeCommand {
#[serde(alias = "stop")]
StopTask(StopTask), StopTask(StopTask),
Stop {},
} }
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)] #[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
@ -35,6 +35,17 @@ pub struct PendingNodeCommand {
envelope: Option<NodeCommandEnvelope>, envelope: Option<NodeCommandEnvelope>,
} }
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct PollCommandsRequest {
machine_id: Uuid,
}
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct ClaimNodeCommandRequest {
machine_id: Uuid,
message_id: String,
}
#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)] #[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum NodeState { pub enum NodeState {
@ -150,7 +161,9 @@ impl Coordinator {
let pending: PendingNodeCommand = serde_json::from_slice(&data)?; let pending: PendingNodeCommand = serde_json::from_slice(&data)?;
if let Some(envelope) = pending.envelope { if let Some(envelope) = pending.envelope {
// TODO: DELETE dequeued command via `message_id`. let request = RequestType::ClaimCommand(envelope.message_id);
self.send_with_auth_retry(request).await?;
Ok(Some(envelope.command)) Ok(Some(envelope.command))
} else { } else {
Ok(None) Ok(None)
@ -188,7 +201,7 @@ impl Coordinator {
&mut self, &mut self,
request_type: RequestType<'a>, request_type: RequestType<'a>,
) -> Result<Response> { ) -> Result<Response> {
let request = self.build_request(request_type)?; let request = self.build_request(request_type.clone())?;
let mut response = self.client.execute(request).await?; let mut response = self.client.execute(request).await?;
if response.status() == StatusCode::UNAUTHORIZED { if response.status() == StatusCode::UNAUTHORIZED {
@ -214,17 +227,40 @@ impl Coordinator {
fn build_request<'a>(&self, request_type: RequestType<'a>) -> Result<Request> { fn build_request<'a>(&self, request_type: RequestType<'a>) -> Result<Request> {
match request_type { match request_type {
RequestType::PollCommands => self.poll_commands_request(), RequestType::PollCommands => self.poll_commands_request(),
RequestType::ClaimCommand(message_id) => self.claim_command_request(message_id),
RequestType::EmitEvent(event) => self.emit_event_request(event), RequestType::EmitEvent(event) => self.emit_event_request(event),
RequestType::CanSchedule(work_set) => self.can_schedule_request(work_set), RequestType::CanSchedule(work_set) => self.can_schedule_request(work_set),
} }
} }
fn poll_commands_request(&self) -> Result<Request> { fn poll_commands_request(&self) -> Result<Request> {
let request = PollCommandsRequest {
machine_id: self.registration.machine_id,
};
let url = self.registration.dynamic_config.commands_url.clone(); let url = self.registration.dynamic_config.commands_url.clone();
let request = self let request = self
.client .client
.get(url) .get(url)
.bearer_auth(self.token.secret().expose()) .bearer_auth(self.token.secret().expose())
.json(&request)
.build()?;
Ok(request)
}
fn claim_command_request(&self, message_id: String) -> Result<Request> {
let request = ClaimNodeCommandRequest {
machine_id: self.registration.machine_id,
message_id,
};
let url = self.registration.dynamic_config.commands_url.clone();
let request = self
.client
.delete(url)
.bearer_auth(self.token.secret().expose())
.json(&request)
.build()?; .build()?;
Ok(request) Ok(request)
@ -271,9 +307,10 @@ impl Coordinator {
// The upstream `Request` type is not `Clone`, so we can't retry a request // The upstream `Request` type is not `Clone`, so we can't retry a request
// without rebuilding it. We use this enum to dispatch to a private method, // without rebuilding it. We use this enum to dispatch to a private method,
// avoiding borrowck conflicts that occur when capturing `self`. // avoiding borrowck conflicts that occur when capturing `self`.
#[derive(Copy, Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug, Eq, PartialEq)]
enum RequestType<'a> { enum RequestType<'a> {
PollCommands, PollCommands,
ClaimCommand(String),
EmitEvent(&'a NodeEventEnvelope), EmitEvent(&'a NodeEventEnvelope),
CanSchedule(&'a WorkSet), CanSchedule(&'a WorkSet),
} }

View File

@ -46,6 +46,10 @@ impl Scheduler {
state.stop(stop_task.task_id)?; state.stop(stop_task.task_id)?;
} }
} }
NodeCommand::Stop {} => {
let state = State { ctx: Done {} };
*self = state.into();
}
} }
Ok(()) Ok(())

View File

@ -84,7 +84,7 @@ impl<'a> Tester<'a> {
.map(|f| f.to_string()) .map(|f| f.to_string())
.collect(); .collect();
let crash_site = if let Some(frame) = call_stack.iter().next() { let crash_site = if let Some(frame) = call_stack.get(0) {
frame.to_string() frame.to_string()
} else { } else {
CRASH_SITE_UNAVAILABLE.to_owned() CRASH_SITE_UNAVAILABLE.to_owned()

View File

@ -4,10 +4,10 @@
# Licensed under the MIT License. # Licensed under the MIT License.
from datetime import datetime from datetime import datetime
from typing import Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field, root_validator, validator
from .consts import ONE_HOUR, SEVEN_DAYS from .consts import ONE_HOUR, SEVEN_DAYS
from .enums import ( from .enums import (
@ -32,6 +32,24 @@ from .enums import (
from .primitives import Container, PoolName, Region from .primitives import Container, PoolName, Region
class EnumModel(BaseModel):
@root_validator(pre=True)
def exactly_one(cls: Any, values: Any) -> Any:
some = []
for field, val in values.items():
if val is not None:
some.append(field)
if not some:
raise ValueError('no variant set for enum')
if len(some) > 1:
raise ValueError('multiple values set for enum: %s' % some)
return values
class Error(BaseModel): class Error(BaseModel):
code: ErrorCode code: ErrorCode
errors: List[str] errors: List[str]
@ -492,11 +510,17 @@ class NodeEventEnvelope(BaseModel):
event: NodeEvent event: NodeEvent
class NodeCommandStopTask(BaseModel): class StopNodeCommand(BaseModel):
pass
class StopTaskNodeCommand(BaseModel):
task_id: UUID task_id: UUID
NodeCommand = Union[NodeCommandStopTask] class NodeCommand(EnumModel):
stop: Optional[StopNodeCommand]
stop_task: Optional[StopTaskNodeCommand]
class NodeCommandEnvelope(BaseModel): class NodeCommandEnvelope(BaseModel):