Add User Info to created tasks (#303)

This PR makes user information from JWT tokens available as part of a Task.

Included changes:
* Renamed `verify_token` to `call_if_agent`, since this function is specific to agent token verification
* Renames `is_authorized` to `is_agent`, since this function checks if the token is an agent
* Adds support for unmanaged nodes in `is_agent` (see #133 for information) 
* Saves the user information from the JWT token on task create as part of `TaskConfig`

Note, `TaskConfig` is what is provided to notification templates.  This enables Github issues and ADO work items to tie back to the user that created the task.

Note, while `upn` _usually_ means email for AAD user tokens.  If we were going to make use of the email address, we should perform a graph lookup based on the `oid`, but we're not.
This commit is contained in:
bmc-msft
2020-11-13 06:50:52 -05:00
committed by GitHub
parent 31f099d3d4
commit beea318968
13 changed files with 254 additions and 71 deletions

View File

@ -4,77 +4,50 @@
# Licensed under the MIT License.
import logging
from typing import Callable, Union
from typing import Callable
from uuid import UUID
import azure.functions as func
import jwt
from memoization import cached
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error
from pydantic import BaseModel
from onefuzztypes.models import Error, UserInfo
from .azure.creds import get_scaleset_principal_id
from .pools import Scaleset
from .pools import Pool, Scaleset
from .request import not_ok
class TokenData(BaseModel):
application_id: UUID
object_id: UUID
def try_get_token_auth_header(request: func.HttpRequest) -> Union[Error, TokenData]:
""" Obtains the Access Token from the Authorization Header """
auth: str = request.headers.get("Authorization", None)
if not auth:
return Error(
code=ErrorCode.INVALID_REQUEST, errors=["Authorization header is expected"]
)
parts = auth.split()
if parts[0].lower() != "bearer":
return Error(
code=ErrorCode.INVALID_REQUEST,
errors=["Authorization header must start with Bearer"],
)
elif len(parts) == 1:
return Error(code=ErrorCode.INVALID_REQUEST, errors=["Token not found"])
elif len(parts) > 2:
return Error(
code=ErrorCode.INVALID_REQUEST,
errors=["Authorization header must be Bearer token"],
)
# This token has already been verified by the azure authentication layer
token = jwt.decode(parts[1], verify=False)
return TokenData(application_id=UUID(token["appid"]), object_id=UUID(token["oid"]))
from .user_credentials import parse_jwt_token
@cached(ttl=60)
def is_authorized(token_data: TokenData) -> bool:
# backward compatibility case for scalesets deployed before the migration
# to user assigned managed id
scalesets = Scaleset.get_by_object_id(token_data.object_id)
if len(scalesets) > 0:
def is_agent(token_data: UserInfo) -> bool:
if token_data.object_id:
# backward compatibility case for scalesets deployed before the migration
# to user assigned managed id
scalesets = Scaleset.get_by_object_id(token_data.object_id)
if len(scalesets) > 0:
return True
# verify object_id against the user assigned managed identity
principal_id: UUID = get_scaleset_principal_id()
return principal_id == token_data.object_id
pools = Pool.search(query={"client_id": [token_data.application_id]})
if len(pools) > 0:
return True
# verify object_id against the user assigned managed identity
principal_id: UUID = get_scaleset_principal_id()
return principal_id == token_data.object_id
return False
def verify_token(
def call_if_agent(
req: func.HttpRequest, method: Callable[[func.HttpRequest], func.HttpResponse]
) -> func.HttpResponse:
token = try_get_token_auth_header(req)
token = parse_jwt_token(req)
if isinstance(token, Error):
return not_ok(token, status_code=401, context="token verification")
if not is_authorized(token):
if not is_agent(token):
logging.error(
"rejecting token url:%s token:%s body:%s",
repr(req.url),

View File

@ -11,7 +11,7 @@ from uuid import UUID
from onefuzztypes.enums import ErrorCode, TaskState
from onefuzztypes.models import Error
from onefuzztypes.models import Task as BASE_TASK
from onefuzztypes.models import TaskConfig, TaskVm
from onefuzztypes.models import TaskConfig, TaskVm, UserInfo
from onefuzztypes.webhooks import (
WebhookEventTaskCreated,
WebhookEventTaskFailed,
@ -42,7 +42,9 @@ class Task(BASE_TASK, ORMMixin):
return True
@classmethod
def create(cls, config: TaskConfig, job_id: UUID) -> Union["Task", Error]:
def create(
cls, config: TaskConfig, job_id: UUID, user_info: UserInfo
) -> Union["Task", Error]:
if config.vm:
os = get_os(config.vm.region, config.vm.image)
elif config.pool:
@ -52,11 +54,14 @@ class Task(BASE_TASK, ORMMixin):
os = pool.os
else:
raise Exception("task must have vm or pool")
task = cls(config=config, job_id=job_id, os=os)
task = cls(config=config, job_id=job_id, os=os, user_info=user_info)
task.save()
Webhook.send_event(
WebhookEventTaskCreated(
job_id=task.job_id, task_id=task.task_id, config=config
job_id=task.job_id,
task_id=task.task_id,
config=config,
user_info=user_info,
)
)
return task
@ -183,7 +188,9 @@ class Task(BASE_TASK, ORMMixin):
self.state = TaskState.stopping
self.save()
Webhook.send_event(
WebhookEventTaskStopped(job_id=self.job_id, task_id=self.task_id)
WebhookEventTaskStopped(
job_id=self.job_id, task_id=self.task_id, user_info=self.user_info
)
)
def mark_failed(self, error: Error) -> None:
@ -199,7 +206,10 @@ class Task(BASE_TASK, ORMMixin):
Webhook.send_event(
WebhookEventTaskFailed(
job_id=self.job_id, task_id=self.task_id, error=error
job_id=self.job_id,
task_id=self.task_id,
error=error,
user_info=self.user_info,
)
)

View File

@ -0,0 +1,42 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from uuid import UUID
import azure.functions as func
import jwt
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error, Result, UserInfo
def parse_jwt_token(request: func.HttpRequest) -> Result[UserInfo]:
""" Obtains the Access Token from the Authorization Header """
auth: str = request.headers.get("Authorization", None)
if not auth:
return Error(
code=ErrorCode.INVALID_REQUEST, errors=["Authorization header is expected"]
)
parts = auth.split()
if len(parts) != 2:
return Error(
code=ErrorCode.INVALID_REQUEST, errors=["Invalid authorization header"]
)
if parts[0].lower() != "bearer":
return Error(
code=ErrorCode.INVALID_REQUEST,
errors=["Authorization header must start with Bearer"],
)
# This token has already been verified by the azure authentication layer
token = jwt.decode(parts[1], verify=False)
application_id = UUID(token["appid"])
object_id = UUID(token["oid"]) if "oid" in token else None
upn = token.get("upn")
return UserInfo(application_id=application_id, object_id=object_id, upn=upn)

View File

@ -231,7 +231,7 @@ def build_message(
WebhookMessage(
webhook_id=webhook_id, event_id=event_id, event_type=event_type, event=event
)
.json(sort_keys=True)
.json(sort_keys=True, exclude_none=True)
.encode()
)
digest = None