mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-16 11:58:09 +00:00
Add User Info to created tasks (#303)
This PR makes user information from JWT tokens available as part of a Task. Included changes: * Renamed `verify_token` to `call_if_agent`, since this function is specific to agent token verification * Renames `is_authorized` to `is_agent`, since this function checks if the token is an agent * Adds support for unmanaged nodes in `is_agent` (see #133 for information) * Saves the user information from the JWT token on task create as part of `TaskConfig` Note, `TaskConfig` is what is provided to notification templates. This enables Github issues and ADO work items to tie back to the user that created the task. Note, while `upn` _usually_ means email for AAD user tokens. If we were going to make use of the email address, we should perform a graph lookup based on the `oid`, but we're not.
This commit is contained in:
@ -4,77 +4,50 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
from typing import Callable, Union
|
||||
from typing import Callable
|
||||
from uuid import UUID
|
||||
|
||||
import azure.functions as func
|
||||
import jwt
|
||||
from memoization import cached
|
||||
from onefuzztypes.enums import ErrorCode
|
||||
from onefuzztypes.models import Error
|
||||
from pydantic import BaseModel
|
||||
from onefuzztypes.models import Error, UserInfo
|
||||
|
||||
from .azure.creds import get_scaleset_principal_id
|
||||
from .pools import Scaleset
|
||||
from .pools import Pool, Scaleset
|
||||
from .request import not_ok
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
application_id: UUID
|
||||
object_id: UUID
|
||||
|
||||
|
||||
def try_get_token_auth_header(request: func.HttpRequest) -> Union[Error, TokenData]:
|
||||
""" Obtains the Access Token from the Authorization Header """
|
||||
auth: str = request.headers.get("Authorization", None)
|
||||
if not auth:
|
||||
return Error(
|
||||
code=ErrorCode.INVALID_REQUEST, errors=["Authorization header is expected"]
|
||||
)
|
||||
parts = auth.split()
|
||||
|
||||
if parts[0].lower() != "bearer":
|
||||
return Error(
|
||||
code=ErrorCode.INVALID_REQUEST,
|
||||
errors=["Authorization header must start with Bearer"],
|
||||
)
|
||||
|
||||
elif len(parts) == 1:
|
||||
return Error(code=ErrorCode.INVALID_REQUEST, errors=["Token not found"])
|
||||
|
||||
elif len(parts) > 2:
|
||||
return Error(
|
||||
code=ErrorCode.INVALID_REQUEST,
|
||||
errors=["Authorization header must be Bearer token"],
|
||||
)
|
||||
|
||||
# This token has already been verified by the azure authentication layer
|
||||
token = jwt.decode(parts[1], verify=False)
|
||||
return TokenData(application_id=UUID(token["appid"]), object_id=UUID(token["oid"]))
|
||||
from .user_credentials import parse_jwt_token
|
||||
|
||||
|
||||
@cached(ttl=60)
|
||||
def is_authorized(token_data: TokenData) -> bool:
|
||||
# backward compatibility case for scalesets deployed before the migration
|
||||
# to user assigned managed id
|
||||
scalesets = Scaleset.get_by_object_id(token_data.object_id)
|
||||
if len(scalesets) > 0:
|
||||
def is_agent(token_data: UserInfo) -> bool:
|
||||
|
||||
if token_data.object_id:
|
||||
# backward compatibility case for scalesets deployed before the migration
|
||||
# to user assigned managed id
|
||||
scalesets = Scaleset.get_by_object_id(token_data.object_id)
|
||||
if len(scalesets) > 0:
|
||||
return True
|
||||
|
||||
# verify object_id against the user assigned managed identity
|
||||
principal_id: UUID = get_scaleset_principal_id()
|
||||
return principal_id == token_data.object_id
|
||||
|
||||
pools = Pool.search(query={"client_id": [token_data.application_id]})
|
||||
if len(pools) > 0:
|
||||
return True
|
||||
|
||||
# verify object_id against the user assigned managed identity
|
||||
principal_id: UUID = get_scaleset_principal_id()
|
||||
return principal_id == token_data.object_id
|
||||
return False
|
||||
|
||||
|
||||
def verify_token(
|
||||
def call_if_agent(
|
||||
req: func.HttpRequest, method: Callable[[func.HttpRequest], func.HttpResponse]
|
||||
) -> func.HttpResponse:
|
||||
token = try_get_token_auth_header(req)
|
||||
|
||||
token = parse_jwt_token(req)
|
||||
if isinstance(token, Error):
|
||||
return not_ok(token, status_code=401, context="token verification")
|
||||
|
||||
if not is_authorized(token):
|
||||
if not is_agent(token):
|
||||
logging.error(
|
||||
"rejecting token url:%s token:%s body:%s",
|
||||
repr(req.url),
|
||||
|
@ -11,7 +11,7 @@ from uuid import UUID
|
||||
from onefuzztypes.enums import ErrorCode, TaskState
|
||||
from onefuzztypes.models import Error
|
||||
from onefuzztypes.models import Task as BASE_TASK
|
||||
from onefuzztypes.models import TaskConfig, TaskVm
|
||||
from onefuzztypes.models import TaskConfig, TaskVm, UserInfo
|
||||
from onefuzztypes.webhooks import (
|
||||
WebhookEventTaskCreated,
|
||||
WebhookEventTaskFailed,
|
||||
@ -42,7 +42,9 @@ class Task(BASE_TASK, ORMMixin):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def create(cls, config: TaskConfig, job_id: UUID) -> Union["Task", Error]:
|
||||
def create(
|
||||
cls, config: TaskConfig, job_id: UUID, user_info: UserInfo
|
||||
) -> Union["Task", Error]:
|
||||
if config.vm:
|
||||
os = get_os(config.vm.region, config.vm.image)
|
||||
elif config.pool:
|
||||
@ -52,11 +54,14 @@ class Task(BASE_TASK, ORMMixin):
|
||||
os = pool.os
|
||||
else:
|
||||
raise Exception("task must have vm or pool")
|
||||
task = cls(config=config, job_id=job_id, os=os)
|
||||
task = cls(config=config, job_id=job_id, os=os, user_info=user_info)
|
||||
task.save()
|
||||
Webhook.send_event(
|
||||
WebhookEventTaskCreated(
|
||||
job_id=task.job_id, task_id=task.task_id, config=config
|
||||
job_id=task.job_id,
|
||||
task_id=task.task_id,
|
||||
config=config,
|
||||
user_info=user_info,
|
||||
)
|
||||
)
|
||||
return task
|
||||
@ -183,7 +188,9 @@ class Task(BASE_TASK, ORMMixin):
|
||||
self.state = TaskState.stopping
|
||||
self.save()
|
||||
Webhook.send_event(
|
||||
WebhookEventTaskStopped(job_id=self.job_id, task_id=self.task_id)
|
||||
WebhookEventTaskStopped(
|
||||
job_id=self.job_id, task_id=self.task_id, user_info=self.user_info
|
||||
)
|
||||
)
|
||||
|
||||
def mark_failed(self, error: Error) -> None:
|
||||
@ -199,7 +206,10 @@ class Task(BASE_TASK, ORMMixin):
|
||||
|
||||
Webhook.send_event(
|
||||
WebhookEventTaskFailed(
|
||||
job_id=self.job_id, task_id=self.task_id, error=error
|
||||
job_id=self.job_id,
|
||||
task_id=self.task_id,
|
||||
error=error,
|
||||
user_info=self.user_info,
|
||||
)
|
||||
)
|
||||
|
||||
|
42
src/api-service/__app__/onefuzzlib/user_credentials.py
Normal file
42
src/api-service/__app__/onefuzzlib/user_credentials.py
Normal file
@ -0,0 +1,42 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
import azure.functions as func
|
||||
import jwt
|
||||
from onefuzztypes.enums import ErrorCode
|
||||
from onefuzztypes.models import Error, Result, UserInfo
|
||||
|
||||
|
||||
def parse_jwt_token(request: func.HttpRequest) -> Result[UserInfo]:
|
||||
""" Obtains the Access Token from the Authorization Header """
|
||||
|
||||
auth: str = request.headers.get("Authorization", None)
|
||||
if not auth:
|
||||
return Error(
|
||||
code=ErrorCode.INVALID_REQUEST, errors=["Authorization header is expected"]
|
||||
)
|
||||
|
||||
parts = auth.split()
|
||||
|
||||
if len(parts) != 2:
|
||||
return Error(
|
||||
code=ErrorCode.INVALID_REQUEST, errors=["Invalid authorization header"]
|
||||
)
|
||||
|
||||
if parts[0].lower() != "bearer":
|
||||
return Error(
|
||||
code=ErrorCode.INVALID_REQUEST,
|
||||
errors=["Authorization header must start with Bearer"],
|
||||
)
|
||||
|
||||
# This token has already been verified by the azure authentication layer
|
||||
token = jwt.decode(parts[1], verify=False)
|
||||
|
||||
application_id = UUID(token["appid"])
|
||||
object_id = UUID(token["oid"]) if "oid" in token else None
|
||||
upn = token.get("upn")
|
||||
return UserInfo(application_id=application_id, object_id=object_id, upn=upn)
|
@ -231,7 +231,7 @@ def build_message(
|
||||
WebhookMessage(
|
||||
webhook_id=webhook_id, event_id=event_id, event_type=event_type, event=event
|
||||
)
|
||||
.json(sort_keys=True)
|
||||
.json(sort_keys=True, exclude_none=True)
|
||||
.encode()
|
||||
)
|
||||
digest = None
|
||||
|
Reference in New Issue
Block a user