Add User Info to created tasks (#303)

This PR makes user information from JWT tokens available as part of a Task.

Included changes:
* Renamed `verify_token` to `call_if_agent`, since this function is specific to agent token verification
* Renames `is_authorized` to `is_agent`, since this function checks if the token is an agent
* Adds support for unmanaged nodes in `is_agent` (see #133 for information) 
* Saves the user information from the JWT token on task create as part of `TaskConfig`

Note, `TaskConfig` is what is provided to notification templates.  This enables Github issues and ADO work items to tie back to the user that created the task.

Note, while `upn` _usually_ means email for AAD user tokens.  If we were going to make use of the email address, we should perform a graph lookup based on the `oid`, but we're not.
This commit is contained in:
bmc-msft
2020-11-13 06:50:52 -05:00
committed by GitHub
parent 31f099d3d4
commit beea318968
13 changed files with 254 additions and 71 deletions

View File

@ -54,6 +54,11 @@ Each event will be submitted via HTTP POST to the user provided URL.
} }
], ],
"tags": {} "tags": {}
},
"user_info": {
"application_id": "00000000-0000-0000-0000-000000000000",
"object_id": "00000000-0000-0000-0000-000000000000",
"upn": "example@contoso.com"
} }
} }
``` ```
@ -77,6 +82,9 @@ Each event will be submitted via HTTP POST to the user provided URL.
}, },
"config": { "config": {
"$ref": "#/definitions/TaskConfig" "$ref": "#/definitions/TaskConfig"
},
"user_info": {
"$ref": "#/definitions/UserInfo"
} }
}, },
"required": [ "required": [
@ -401,6 +409,29 @@ Each event will be submitted via HTTP POST to the user provided URL.
"containers", "containers",
"tags" "tags"
] ]
},
"UserInfo": {
"title": "UserInfo",
"type": "object",
"properties": {
"application_id": {
"title": "Application Id",
"type": "string",
"format": "uuid"
},
"object_id": {
"title": "Object Id",
"type": "string",
"format": "uuid"
},
"upn": {
"title": "Upn",
"type": "string"
}
},
"required": [
"application_id"
]
} }
} }
} }
@ -413,7 +444,12 @@ Each event will be submitted via HTTP POST to the user provided URL.
```json ```json
{ {
"job_id": "00000000-0000-0000-0000-000000000000", "job_id": "00000000-0000-0000-0000-000000000000",
"task_id": "00000000-0000-0000-0000-000000000000" "task_id": "00000000-0000-0000-0000-000000000000",
"user_info": {
"application_id": "00000000-0000-0000-0000-000000000000",
"object_id": "00000000-0000-0000-0000-000000000000",
"upn": "example@contoso.com"
}
} }
``` ```
@ -433,12 +469,40 @@ Each event will be submitted via HTTP POST to the user provided URL.
"title": "Task Id", "title": "Task Id",
"type": "string", "type": "string",
"format": "uuid" "format": "uuid"
},
"user_info": {
"$ref": "#/definitions/UserInfo"
} }
}, },
"required": [ "required": [
"job_id", "job_id",
"task_id" "task_id"
] ],
"definitions": {
"UserInfo": {
"title": "UserInfo",
"type": "object",
"properties": {
"application_id": {
"title": "Application Id",
"type": "string",
"format": "uuid"
},
"object_id": {
"title": "Object Id",
"type": "string",
"format": "uuid"
},
"upn": {
"title": "Upn",
"type": "string"
}
},
"required": [
"application_id"
]
}
}
} }
``` ```
@ -455,6 +519,11 @@ Each event will be submitted via HTTP POST to the user provided URL.
"errors": [ "errors": [
"example error message" "example error message"
] ]
},
"user_info": {
"application_id": "00000000-0000-0000-0000-000000000000",
"object_id": "00000000-0000-0000-0000-000000000000",
"upn": "example@contoso.com"
} }
} }
``` ```
@ -478,6 +547,9 @@ Each event will be submitted via HTTP POST to the user provided URL.
}, },
"error": { "error": {
"$ref": "#/definitions/Error" "$ref": "#/definitions/Error"
},
"user_info": {
"$ref": "#/definitions/UserInfo"
} }
}, },
"required": [ "required": [
@ -531,6 +603,29 @@ Each event will be submitted via HTTP POST to the user provided URL.
"code", "code",
"errors" "errors"
] ]
},
"UserInfo": {
"title": "UserInfo",
"type": "object",
"properties": {
"application_id": {
"title": "Application Id",
"type": "string",
"format": "uuid"
},
"object_id": {
"title": "Object Id",
"type": "string",
"format": "uuid"
},
"upn": {
"title": "Upn",
"type": "string"
}
},
"required": [
"application_id"
]
} }
} }
} }
@ -933,6 +1028,29 @@ Each event will be submitted via HTTP POST to the user provided URL.
"tags" "tags"
] ]
}, },
"UserInfo": {
"title": "UserInfo",
"type": "object",
"properties": {
"application_id": {
"title": "Application Id",
"type": "string",
"format": "uuid"
},
"object_id": {
"title": "Object Id",
"type": "string",
"format": "uuid"
},
"upn": {
"title": "Upn",
"type": "string"
}
},
"required": [
"application_id"
]
},
"WebhookEventTaskCreated": { "WebhookEventTaskCreated": {
"title": "WebhookEventTaskCreated", "title": "WebhookEventTaskCreated",
"type": "object", "type": "object",
@ -949,6 +1067,9 @@ Each event will be submitted via HTTP POST to the user provided URL.
}, },
"config": { "config": {
"$ref": "#/definitions/TaskConfig" "$ref": "#/definitions/TaskConfig"
},
"user_info": {
"$ref": "#/definitions/UserInfo"
} }
}, },
"required": [ "required": [
@ -970,6 +1091,9 @@ Each event will be submitted via HTTP POST to the user provided URL.
"title": "Task Id", "title": "Task Id",
"type": "string", "type": "string",
"format": "uuid" "format": "uuid"
},
"user_info": {
"$ref": "#/definitions/UserInfo"
} }
}, },
"required": [ "required": [
@ -1039,6 +1163,9 @@ Each event will be submitted via HTTP POST to the user provided URL.
}, },
"error": { "error": {
"$ref": "#/definitions/Error" "$ref": "#/definitions/Error"
},
"user_info": {
"$ref": "#/definitions/UserInfo"
} }
}, },
"required": [ "required": [

View File

@ -9,7 +9,7 @@ from onefuzztypes.models import Error
from onefuzztypes.requests import CanScheduleRequest from onefuzztypes.requests import CanScheduleRequest
from onefuzztypes.responses import CanSchedule from onefuzztypes.responses import CanSchedule
from ..onefuzzlib.agent_authorization import verify_token from ..onefuzzlib.agent_authorization import call_if_agent
from ..onefuzzlib.pools import Node from ..onefuzzlib.pools import Node
from ..onefuzzlib.request import not_ok, ok, parse_request from ..onefuzzlib.request import not_ok, ok, parse_request
from ..onefuzzlib.tasks.main import Task from ..onefuzzlib.tasks.main import Task
@ -48,4 +48,4 @@ def main(req: func.HttpRequest) -> func.HttpResponse:
else: else:
raise Exception("invalid method") raise Exception("invalid method")
return verify_token(req, m) return call_if_agent(req, m)

View File

@ -8,7 +8,7 @@ from onefuzztypes.models import Error, NodeCommandEnvelope
from onefuzztypes.requests import NodeCommandDelete, NodeCommandGet from onefuzztypes.requests import NodeCommandDelete, NodeCommandGet
from onefuzztypes.responses import BoolResult, PendingNodeCommand from onefuzztypes.responses import BoolResult, PendingNodeCommand
from ..onefuzzlib.agent_authorization import verify_token from ..onefuzzlib.agent_authorization import call_if_agent
from ..onefuzzlib.pools import NodeMessage from ..onefuzzlib.pools import NodeMessage
from ..onefuzzlib.request import not_ok, ok, parse_request from ..onefuzzlib.request import not_ok, ok, parse_request
@ -50,4 +50,4 @@ def main(req: func.HttpRequest) -> func.HttpResponse:
else: else:
raise Exception("invalid method") raise Exception("invalid method")
return verify_token(req, m) return call_if_agent(req, m)

View File

@ -16,7 +16,7 @@ from onefuzztypes.models import (
) )
from onefuzztypes.responses import BoolResult from onefuzztypes.responses import BoolResult
from ..onefuzzlib.agent_authorization import verify_token from ..onefuzzlib.agent_authorization import call_if_agent
from ..onefuzzlib.agent_events import on_state_update, on_worker_event from ..onefuzzlib.agent_events import on_state_update, on_worker_event
from ..onefuzzlib.request import not_ok, ok, parse_request from ..onefuzzlib.request import not_ok, ok, parse_request
@ -72,4 +72,4 @@ def post(req: func.HttpRequest) -> func.HttpResponse:
def main(req: func.HttpRequest) -> func.HttpResponse: def main(req: func.HttpRequest) -> func.HttpResponse:
return verify_token(req, post) return call_if_agent(req, post)

View File

@ -12,7 +12,7 @@ from onefuzztypes.models import Error
from onefuzztypes.requests import AgentRegistrationGet, AgentRegistrationPost from onefuzztypes.requests import AgentRegistrationGet, AgentRegistrationPost
from onefuzztypes.responses import AgentRegistration from onefuzztypes.responses import AgentRegistration
from ..onefuzzlib.agent_authorization import verify_token from ..onefuzzlib.agent_authorization import call_if_agent
from ..onefuzzlib.azure.creds import get_fuzz_storage, get_instance_url from ..onefuzzlib.azure.creds import get_fuzz_storage, get_instance_url
from ..onefuzzlib.azure.queue import get_queue_sas from ..onefuzzlib.azure.queue import get_queue_sas
from ..onefuzzlib.pools import Node, NodeMessage, NodeTasks, Pool from ..onefuzzlib.pools import Node, NodeMessage, NodeTasks, Pool
@ -122,4 +122,4 @@ def main(req: func.HttpRequest) -> func.HttpResponse:
else: else:
raise Exception("invalid method") raise Exception("invalid method")
return verify_token(req, m) return call_if_agent(req, m)

View File

@ -4,77 +4,50 @@
# Licensed under the MIT License. # Licensed under the MIT License.
import logging import logging
from typing import Callable, Union from typing import Callable
from uuid import UUID from uuid import UUID
import azure.functions as func import azure.functions as func
import jwt
from memoization import cached from memoization import cached
from onefuzztypes.enums import ErrorCode from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error from onefuzztypes.models import Error, UserInfo
from pydantic import BaseModel
from .azure.creds import get_scaleset_principal_id from .azure.creds import get_scaleset_principal_id
from .pools import Scaleset from .pools import Pool, Scaleset
from .request import not_ok from .request import not_ok
from .user_credentials import parse_jwt_token
class TokenData(BaseModel):
application_id: UUID
object_id: UUID
def try_get_token_auth_header(request: func.HttpRequest) -> Union[Error, TokenData]:
""" Obtains the Access Token from the Authorization Header """
auth: str = request.headers.get("Authorization", None)
if not auth:
return Error(
code=ErrorCode.INVALID_REQUEST, errors=["Authorization header is expected"]
)
parts = auth.split()
if parts[0].lower() != "bearer":
return Error(
code=ErrorCode.INVALID_REQUEST,
errors=["Authorization header must start with Bearer"],
)
elif len(parts) == 1:
return Error(code=ErrorCode.INVALID_REQUEST, errors=["Token not found"])
elif len(parts) > 2:
return Error(
code=ErrorCode.INVALID_REQUEST,
errors=["Authorization header must be Bearer token"],
)
# This token has already been verified by the azure authentication layer
token = jwt.decode(parts[1], verify=False)
return TokenData(application_id=UUID(token["appid"]), object_id=UUID(token["oid"]))
@cached(ttl=60) @cached(ttl=60)
def is_authorized(token_data: TokenData) -> bool: def is_agent(token_data: UserInfo) -> bool:
# backward compatibility case for scalesets deployed before the migration
# to user assigned managed id if token_data.object_id:
scalesets = Scaleset.get_by_object_id(token_data.object_id) # backward compatibility case for scalesets deployed before the migration
if len(scalesets) > 0: # to user assigned managed id
scalesets = Scaleset.get_by_object_id(token_data.object_id)
if len(scalesets) > 0:
return True
# verify object_id against the user assigned managed identity
principal_id: UUID = get_scaleset_principal_id()
return principal_id == token_data.object_id
pools = Pool.search(query={"client_id": [token_data.application_id]})
if len(pools) > 0:
return True return True
# verify object_id against the user assigned managed identity return False
principal_id: UUID = get_scaleset_principal_id()
return principal_id == token_data.object_id
def verify_token( def call_if_agent(
req: func.HttpRequest, method: Callable[[func.HttpRequest], func.HttpResponse] req: func.HttpRequest, method: Callable[[func.HttpRequest], func.HttpResponse]
) -> func.HttpResponse: ) -> func.HttpResponse:
token = try_get_token_auth_header(req)
token = parse_jwt_token(req)
if isinstance(token, Error): if isinstance(token, Error):
return not_ok(token, status_code=401, context="token verification") return not_ok(token, status_code=401, context="token verification")
if not is_authorized(token): if not is_agent(token):
logging.error( logging.error(
"rejecting token url:%s token:%s body:%s", "rejecting token url:%s token:%s body:%s",
repr(req.url), repr(req.url),

View File

@ -11,7 +11,7 @@ from uuid import UUID
from onefuzztypes.enums import ErrorCode, TaskState from onefuzztypes.enums import ErrorCode, TaskState
from onefuzztypes.models import Error from onefuzztypes.models import Error
from onefuzztypes.models import Task as BASE_TASK from onefuzztypes.models import Task as BASE_TASK
from onefuzztypes.models import TaskConfig, TaskVm from onefuzztypes.models import TaskConfig, TaskVm, UserInfo
from onefuzztypes.webhooks import ( from onefuzztypes.webhooks import (
WebhookEventTaskCreated, WebhookEventTaskCreated,
WebhookEventTaskFailed, WebhookEventTaskFailed,
@ -42,7 +42,9 @@ class Task(BASE_TASK, ORMMixin):
return True return True
@classmethod @classmethod
def create(cls, config: TaskConfig, job_id: UUID) -> Union["Task", Error]: def create(
cls, config: TaskConfig, job_id: UUID, user_info: UserInfo
) -> Union["Task", Error]:
if config.vm: if config.vm:
os = get_os(config.vm.region, config.vm.image) os = get_os(config.vm.region, config.vm.image)
elif config.pool: elif config.pool:
@ -52,11 +54,14 @@ class Task(BASE_TASK, ORMMixin):
os = pool.os os = pool.os
else: else:
raise Exception("task must have vm or pool") raise Exception("task must have vm or pool")
task = cls(config=config, job_id=job_id, os=os) task = cls(config=config, job_id=job_id, os=os, user_info=user_info)
task.save() task.save()
Webhook.send_event( Webhook.send_event(
WebhookEventTaskCreated( WebhookEventTaskCreated(
job_id=task.job_id, task_id=task.task_id, config=config job_id=task.job_id,
task_id=task.task_id,
config=config,
user_info=user_info,
) )
) )
return task return task
@ -183,7 +188,9 @@ class Task(BASE_TASK, ORMMixin):
self.state = TaskState.stopping self.state = TaskState.stopping
self.save() self.save()
Webhook.send_event( Webhook.send_event(
WebhookEventTaskStopped(job_id=self.job_id, task_id=self.task_id) WebhookEventTaskStopped(
job_id=self.job_id, task_id=self.task_id, user_info=self.user_info
)
) )
def mark_failed(self, error: Error) -> None: def mark_failed(self, error: Error) -> None:
@ -199,7 +206,10 @@ class Task(BASE_TASK, ORMMixin):
Webhook.send_event( Webhook.send_event(
WebhookEventTaskFailed( WebhookEventTaskFailed(
job_id=self.job_id, task_id=self.task_id, error=error job_id=self.job_id,
task_id=self.task_id,
error=error,
user_info=self.user_info,
) )
) )

View File

@ -0,0 +1,42 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from uuid import UUID
import azure.functions as func
import jwt
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error, Result, UserInfo
def parse_jwt_token(request: func.HttpRequest) -> Result[UserInfo]:
""" Obtains the Access Token from the Authorization Header """
auth: str = request.headers.get("Authorization", None)
if not auth:
return Error(
code=ErrorCode.INVALID_REQUEST, errors=["Authorization header is expected"]
)
parts = auth.split()
if len(parts) != 2:
return Error(
code=ErrorCode.INVALID_REQUEST, errors=["Invalid authorization header"]
)
if parts[0].lower() != "bearer":
return Error(
code=ErrorCode.INVALID_REQUEST,
errors=["Authorization header must start with Bearer"],
)
# This token has already been verified by the azure authentication layer
token = jwt.decode(parts[1], verify=False)
application_id = UUID(token["appid"])
object_id = UUID(token["oid"]) if "oid" in token else None
upn = token.get("upn")
return UserInfo(application_id=application_id, object_id=object_id, upn=upn)

View File

@ -231,7 +231,7 @@ def build_message(
WebhookMessage( WebhookMessage(
webhook_id=webhook_id, event_id=event_id, event_type=event_type, event=event webhook_id=webhook_id, event_id=event_id, event_type=event_type, event=event
) )
.json(sort_keys=True) .json(sort_keys=True, exclude_none=True)
.encode() .encode()
) )
digest = None digest = None

View File

@ -16,6 +16,7 @@ from ..onefuzzlib.request import not_ok, ok, parse_request
from ..onefuzzlib.task_event import TaskEvent from ..onefuzzlib.task_event import TaskEvent
from ..onefuzzlib.tasks.config import TaskConfigError, check_config from ..onefuzzlib.tasks.config import TaskConfigError, check_config
from ..onefuzzlib.tasks.main import Task from ..onefuzzlib.tasks.main import Task
from ..onefuzzlib.user_credentials import parse_jwt_token
def post(req: func.HttpRequest) -> func.HttpResponse: def post(req: func.HttpRequest) -> func.HttpResponse:
@ -23,6 +24,10 @@ def post(req: func.HttpRequest) -> func.HttpResponse:
if isinstance(request, Error): if isinstance(request, Error):
return not_ok(request, context="task create") return not_ok(request, context="task create")
user_info = parse_jwt_token(req)
if isinstance(user_info, Error):
return not_ok(user_info, context="task create")
try: try:
check_config(request) check_config(request)
except TaskConfigError as err: except TaskConfigError as err:
@ -56,7 +61,7 @@ def post(req: func.HttpRequest) -> func.HttpResponse:
if isinstance(prereq, Error): if isinstance(prereq, Error):
return not_ok(prereq, context="task create prerequisite") return not_ok(prereq, context="task create prerequisite")
task = Task.create(config=request, job_id=request.job_id) task = Task.create(config=request, job_id=request.job_id, user_info=user_info)
if isinstance(task, Error): if isinstance(task, Error):
return not_ok(task, context="task create invalid pool") return not_ok(task, context="task create invalid pool")
return ok(task) return ok(task)

View File

@ -7,7 +7,7 @@ from typing import Optional
from uuid import UUID from uuid import UUID
import json import json
from onefuzztypes.enums import TaskType, ContainerType, ErrorCode from onefuzztypes.enums import TaskType, ContainerType, ErrorCode
from onefuzztypes.models import TaskConfig, TaskDetails, TaskContainers, Error from onefuzztypes.models import TaskConfig, TaskDetails, TaskContainers, Error, UserInfo
from onefuzztypes.webhooks import ( from onefuzztypes.webhooks import (
WebhookMessage, WebhookMessage,
WebhookEventPing, WebhookEventPing,
@ -32,12 +32,23 @@ def main():
examples = { examples = {
WebhookEventType.ping: WebhookEventPing(ping_id=UUID(int=0)), WebhookEventType.ping: WebhookEventPing(ping_id=UUID(int=0)),
WebhookEventType.task_stopped: WebhookEventTaskStopped( WebhookEventType.task_stopped: WebhookEventTaskStopped(
job_id=UUID(int=0), task_id=UUID(int=0) job_id=UUID(int=0),
task_id=UUID(int=0),
user_info=UserInfo(
application_id=UUID(int=0),
object_id=UUID(int=0),
upn="example@contoso.com",
),
), ),
WebhookEventType.task_failed: WebhookEventTaskFailed( WebhookEventType.task_failed: WebhookEventTaskFailed(
job_id=UUID(int=0), job_id=UUID(int=0),
task_id=UUID(int=0), task_id=UUID(int=0),
error=Error(code=ErrorCode.TASK_FAILED, errors=["example error message"]), error=Error(code=ErrorCode.TASK_FAILED, errors=["example error message"]),
user_info=UserInfo(
application_id=UUID(int=0),
object_id=UUID(int=0),
upn="example@contoso.com",
),
), ),
WebhookEventType.task_created: WebhookEventTaskCreated( WebhookEventType.task_created: WebhookEventTaskCreated(
job_id=UUID(int=0), job_id=UUID(int=0),
@ -58,6 +69,11 @@ def main():
], ],
tags={}, tags={},
), ),
user_info=UserInfo(
application_id=UUID(int=0),
object_id=UUID(int=0),
upn="example@contoso.com",
),
), ),
} }

View File

@ -35,6 +35,12 @@ from .enums import (
from .primitives import Container, PoolName, Region from .primitives import Container, PoolName, Region
class UserInfo(BaseModel):
application_id: UUID
object_id: Optional[UUID]
upn: Optional[str]
class EnumModel(BaseModel): class EnumModel(BaseModel):
@root_validator(pre=True) @root_validator(pre=True)
def exactly_one(cls: Any, values: Any) -> Any: def exactly_one(cls: Any, values: Any) -> Any:
@ -704,3 +710,4 @@ class Task(BaseModel):
end_time: Optional[datetime] end_time: Optional[datetime]
events: Optional[List[TaskEventSummary]] events: Optional[List[TaskEventSummary]]
nodes: Optional[List[NodeAssignment]] nodes: Optional[List[NodeAssignment]]
user_info: Optional[UserInfo]

View File

@ -4,25 +4,28 @@ from uuid import UUID, uuid4
from pydantic import AnyHttpUrl, BaseModel, Field from pydantic import AnyHttpUrl, BaseModel, Field
from .enums import WebhookEventType, WebhookMessageState from .enums import WebhookEventType, WebhookMessageState
from .models import Error, TaskConfig from .models import Error, TaskConfig, UserInfo
from .responses import BaseResponse from .responses import BaseResponse
class WebhookEventTaskStopped(BaseModel): class WebhookEventTaskStopped(BaseModel):
job_id: UUID job_id: UUID
task_id: UUID task_id: UUID
user_info: Optional[UserInfo]
class WebhookEventTaskFailed(BaseModel): class WebhookEventTaskFailed(BaseModel):
job_id: UUID job_id: UUID
task_id: UUID task_id: UUID
error: Error error: Error
user_info: Optional[UserInfo]
class WebhookEventTaskCreated(BaseModel): class WebhookEventTaskCreated(BaseModel):
job_id: UUID job_id: UUID
task_id: UUID task_id: UUID
config: TaskConfig config: TaskConfig
user_info: Optional[UserInfo]
class WebhookEventPing(BaseResponse): class WebhookEventPing(BaseResponse):