Add User Info to created tasks (#303)

This PR makes user information from JWT tokens available as part of a Task.

Included changes:
* Renamed `verify_token` to `call_if_agent`, since this function is specific to agent token verification
* Renames `is_authorized` to `is_agent`, since this function checks if the token is an agent
* Adds support for unmanaged nodes in `is_agent` (see #133 for information) 
* Saves the user information from the JWT token on task create as part of `TaskConfig`

Note, `TaskConfig` is what is provided to notification templates.  This enables Github issues and ADO work items to tie back to the user that created the task.

Note, while `upn` _usually_ means email for AAD user tokens.  If we were going to make use of the email address, we should perform a graph lookup based on the `oid`, but we're not.
This commit is contained in:
bmc-msft
2020-11-13 06:50:52 -05:00
committed by GitHub
parent 31f099d3d4
commit beea318968
13 changed files with 254 additions and 71 deletions

View File

@ -54,6 +54,11 @@ Each event will be submitted via HTTP POST to the user provided URL.
}
],
"tags": {}
},
"user_info": {
"application_id": "00000000-0000-0000-0000-000000000000",
"object_id": "00000000-0000-0000-0000-000000000000",
"upn": "example@contoso.com"
}
}
```
@ -77,6 +82,9 @@ Each event will be submitted via HTTP POST to the user provided URL.
},
"config": {
"$ref": "#/definitions/TaskConfig"
},
"user_info": {
"$ref": "#/definitions/UserInfo"
}
},
"required": [
@ -401,6 +409,29 @@ Each event will be submitted via HTTP POST to the user provided URL.
"containers",
"tags"
]
},
"UserInfo": {
"title": "UserInfo",
"type": "object",
"properties": {
"application_id": {
"title": "Application Id",
"type": "string",
"format": "uuid"
},
"object_id": {
"title": "Object Id",
"type": "string",
"format": "uuid"
},
"upn": {
"title": "Upn",
"type": "string"
}
},
"required": [
"application_id"
]
}
}
}
@ -413,7 +444,12 @@ Each event will be submitted via HTTP POST to the user provided URL.
```json
{
"job_id": "00000000-0000-0000-0000-000000000000",
"task_id": "00000000-0000-0000-0000-000000000000"
"task_id": "00000000-0000-0000-0000-000000000000",
"user_info": {
"application_id": "00000000-0000-0000-0000-000000000000",
"object_id": "00000000-0000-0000-0000-000000000000",
"upn": "example@contoso.com"
}
}
```
@ -433,12 +469,40 @@ Each event will be submitted via HTTP POST to the user provided URL.
"title": "Task Id",
"type": "string",
"format": "uuid"
},
"user_info": {
"$ref": "#/definitions/UserInfo"
}
},
"required": [
"job_id",
"task_id"
]
],
"definitions": {
"UserInfo": {
"title": "UserInfo",
"type": "object",
"properties": {
"application_id": {
"title": "Application Id",
"type": "string",
"format": "uuid"
},
"object_id": {
"title": "Object Id",
"type": "string",
"format": "uuid"
},
"upn": {
"title": "Upn",
"type": "string"
}
},
"required": [
"application_id"
]
}
}
}
```
@ -455,6 +519,11 @@ Each event will be submitted via HTTP POST to the user provided URL.
"errors": [
"example error message"
]
},
"user_info": {
"application_id": "00000000-0000-0000-0000-000000000000",
"object_id": "00000000-0000-0000-0000-000000000000",
"upn": "example@contoso.com"
}
}
```
@ -478,6 +547,9 @@ Each event will be submitted via HTTP POST to the user provided URL.
},
"error": {
"$ref": "#/definitions/Error"
},
"user_info": {
"$ref": "#/definitions/UserInfo"
}
},
"required": [
@ -531,6 +603,29 @@ Each event will be submitted via HTTP POST to the user provided URL.
"code",
"errors"
]
},
"UserInfo": {
"title": "UserInfo",
"type": "object",
"properties": {
"application_id": {
"title": "Application Id",
"type": "string",
"format": "uuid"
},
"object_id": {
"title": "Object Id",
"type": "string",
"format": "uuid"
},
"upn": {
"title": "Upn",
"type": "string"
}
},
"required": [
"application_id"
]
}
}
}
@ -933,6 +1028,29 @@ Each event will be submitted via HTTP POST to the user provided URL.
"tags"
]
},
"UserInfo": {
"title": "UserInfo",
"type": "object",
"properties": {
"application_id": {
"title": "Application Id",
"type": "string",
"format": "uuid"
},
"object_id": {
"title": "Object Id",
"type": "string",
"format": "uuid"
},
"upn": {
"title": "Upn",
"type": "string"
}
},
"required": [
"application_id"
]
},
"WebhookEventTaskCreated": {
"title": "WebhookEventTaskCreated",
"type": "object",
@ -949,6 +1067,9 @@ Each event will be submitted via HTTP POST to the user provided URL.
},
"config": {
"$ref": "#/definitions/TaskConfig"
},
"user_info": {
"$ref": "#/definitions/UserInfo"
}
},
"required": [
@ -970,6 +1091,9 @@ Each event will be submitted via HTTP POST to the user provided URL.
"title": "Task Id",
"type": "string",
"format": "uuid"
},
"user_info": {
"$ref": "#/definitions/UserInfo"
}
},
"required": [
@ -1039,6 +1163,9 @@ Each event will be submitted via HTTP POST to the user provided URL.
},
"error": {
"$ref": "#/definitions/Error"
},
"user_info": {
"$ref": "#/definitions/UserInfo"
}
},
"required": [

View File

@ -9,7 +9,7 @@ from onefuzztypes.models import Error
from onefuzztypes.requests import CanScheduleRequest
from onefuzztypes.responses import CanSchedule
from ..onefuzzlib.agent_authorization import verify_token
from ..onefuzzlib.agent_authorization import call_if_agent
from ..onefuzzlib.pools import Node
from ..onefuzzlib.request import not_ok, ok, parse_request
from ..onefuzzlib.tasks.main import Task
@ -48,4 +48,4 @@ def main(req: func.HttpRequest) -> func.HttpResponse:
else:
raise Exception("invalid method")
return verify_token(req, m)
return call_if_agent(req, m)

View File

@ -8,7 +8,7 @@ from onefuzztypes.models import Error, NodeCommandEnvelope
from onefuzztypes.requests import NodeCommandDelete, NodeCommandGet
from onefuzztypes.responses import BoolResult, PendingNodeCommand
from ..onefuzzlib.agent_authorization import verify_token
from ..onefuzzlib.agent_authorization import call_if_agent
from ..onefuzzlib.pools import NodeMessage
from ..onefuzzlib.request import not_ok, ok, parse_request
@ -50,4 +50,4 @@ def main(req: func.HttpRequest) -> func.HttpResponse:
else:
raise Exception("invalid method")
return verify_token(req, m)
return call_if_agent(req, m)

View File

@ -16,7 +16,7 @@ from onefuzztypes.models import (
)
from onefuzztypes.responses import BoolResult
from ..onefuzzlib.agent_authorization import verify_token
from ..onefuzzlib.agent_authorization import call_if_agent
from ..onefuzzlib.agent_events import on_state_update, on_worker_event
from ..onefuzzlib.request import not_ok, ok, parse_request
@ -72,4 +72,4 @@ def post(req: func.HttpRequest) -> func.HttpResponse:
def main(req: func.HttpRequest) -> func.HttpResponse:
return verify_token(req, post)
return call_if_agent(req, post)

View File

@ -12,7 +12,7 @@ from onefuzztypes.models import Error
from onefuzztypes.requests import AgentRegistrationGet, AgentRegistrationPost
from onefuzztypes.responses import AgentRegistration
from ..onefuzzlib.agent_authorization import verify_token
from ..onefuzzlib.agent_authorization import call_if_agent
from ..onefuzzlib.azure.creds import get_fuzz_storage, get_instance_url
from ..onefuzzlib.azure.queue import get_queue_sas
from ..onefuzzlib.pools import Node, NodeMessage, NodeTasks, Pool
@ -122,4 +122,4 @@ def main(req: func.HttpRequest) -> func.HttpResponse:
else:
raise Exception("invalid method")
return verify_token(req, m)
return call_if_agent(req, m)

View File

@ -4,77 +4,50 @@
# Licensed under the MIT License.
import logging
from typing import Callable, Union
from typing import Callable
from uuid import UUID
import azure.functions as func
import jwt
from memoization import cached
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error
from pydantic import BaseModel
from onefuzztypes.models import Error, UserInfo
from .azure.creds import get_scaleset_principal_id
from .pools import Scaleset
from .pools import Pool, Scaleset
from .request import not_ok
class TokenData(BaseModel):
application_id: UUID
object_id: UUID
def try_get_token_auth_header(request: func.HttpRequest) -> Union[Error, TokenData]:
""" Obtains the Access Token from the Authorization Header """
auth: str = request.headers.get("Authorization", None)
if not auth:
return Error(
code=ErrorCode.INVALID_REQUEST, errors=["Authorization header is expected"]
)
parts = auth.split()
if parts[0].lower() != "bearer":
return Error(
code=ErrorCode.INVALID_REQUEST,
errors=["Authorization header must start with Bearer"],
)
elif len(parts) == 1:
return Error(code=ErrorCode.INVALID_REQUEST, errors=["Token not found"])
elif len(parts) > 2:
return Error(
code=ErrorCode.INVALID_REQUEST,
errors=["Authorization header must be Bearer token"],
)
# This token has already been verified by the azure authentication layer
token = jwt.decode(parts[1], verify=False)
return TokenData(application_id=UUID(token["appid"]), object_id=UUID(token["oid"]))
from .user_credentials import parse_jwt_token
@cached(ttl=60)
def is_authorized(token_data: TokenData) -> bool:
# backward compatibility case for scalesets deployed before the migration
# to user assigned managed id
scalesets = Scaleset.get_by_object_id(token_data.object_id)
if len(scalesets) > 0:
def is_agent(token_data: UserInfo) -> bool:
if token_data.object_id:
# backward compatibility case for scalesets deployed before the migration
# to user assigned managed id
scalesets = Scaleset.get_by_object_id(token_data.object_id)
if len(scalesets) > 0:
return True
# verify object_id against the user assigned managed identity
principal_id: UUID = get_scaleset_principal_id()
return principal_id == token_data.object_id
pools = Pool.search(query={"client_id": [token_data.application_id]})
if len(pools) > 0:
return True
# verify object_id against the user assigned managed identity
principal_id: UUID = get_scaleset_principal_id()
return principal_id == token_data.object_id
return False
def verify_token(
def call_if_agent(
req: func.HttpRequest, method: Callable[[func.HttpRequest], func.HttpResponse]
) -> func.HttpResponse:
token = try_get_token_auth_header(req)
token = parse_jwt_token(req)
if isinstance(token, Error):
return not_ok(token, status_code=401, context="token verification")
if not is_authorized(token):
if not is_agent(token):
logging.error(
"rejecting token url:%s token:%s body:%s",
repr(req.url),

View File

@ -11,7 +11,7 @@ from uuid import UUID
from onefuzztypes.enums import ErrorCode, TaskState
from onefuzztypes.models import Error
from onefuzztypes.models import Task as BASE_TASK
from onefuzztypes.models import TaskConfig, TaskVm
from onefuzztypes.models import TaskConfig, TaskVm, UserInfo
from onefuzztypes.webhooks import (
WebhookEventTaskCreated,
WebhookEventTaskFailed,
@ -42,7 +42,9 @@ class Task(BASE_TASK, ORMMixin):
return True
@classmethod
def create(cls, config: TaskConfig, job_id: UUID) -> Union["Task", Error]:
def create(
cls, config: TaskConfig, job_id: UUID, user_info: UserInfo
) -> Union["Task", Error]:
if config.vm:
os = get_os(config.vm.region, config.vm.image)
elif config.pool:
@ -52,11 +54,14 @@ class Task(BASE_TASK, ORMMixin):
os = pool.os
else:
raise Exception("task must have vm or pool")
task = cls(config=config, job_id=job_id, os=os)
task = cls(config=config, job_id=job_id, os=os, user_info=user_info)
task.save()
Webhook.send_event(
WebhookEventTaskCreated(
job_id=task.job_id, task_id=task.task_id, config=config
job_id=task.job_id,
task_id=task.task_id,
config=config,
user_info=user_info,
)
)
return task
@ -183,7 +188,9 @@ class Task(BASE_TASK, ORMMixin):
self.state = TaskState.stopping
self.save()
Webhook.send_event(
WebhookEventTaskStopped(job_id=self.job_id, task_id=self.task_id)
WebhookEventTaskStopped(
job_id=self.job_id, task_id=self.task_id, user_info=self.user_info
)
)
def mark_failed(self, error: Error) -> None:
@ -199,7 +206,10 @@ class Task(BASE_TASK, ORMMixin):
Webhook.send_event(
WebhookEventTaskFailed(
job_id=self.job_id, task_id=self.task_id, error=error
job_id=self.job_id,
task_id=self.task_id,
error=error,
user_info=self.user_info,
)
)

View File

@ -0,0 +1,42 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from uuid import UUID
import azure.functions as func
import jwt
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error, Result, UserInfo
def parse_jwt_token(request: func.HttpRequest) -> Result[UserInfo]:
""" Obtains the Access Token from the Authorization Header """
auth: str = request.headers.get("Authorization", None)
if not auth:
return Error(
code=ErrorCode.INVALID_REQUEST, errors=["Authorization header is expected"]
)
parts = auth.split()
if len(parts) != 2:
return Error(
code=ErrorCode.INVALID_REQUEST, errors=["Invalid authorization header"]
)
if parts[0].lower() != "bearer":
return Error(
code=ErrorCode.INVALID_REQUEST,
errors=["Authorization header must start with Bearer"],
)
# This token has already been verified by the azure authentication layer
token = jwt.decode(parts[1], verify=False)
application_id = UUID(token["appid"])
object_id = UUID(token["oid"]) if "oid" in token else None
upn = token.get("upn")
return UserInfo(application_id=application_id, object_id=object_id, upn=upn)

View File

@ -231,7 +231,7 @@ def build_message(
WebhookMessage(
webhook_id=webhook_id, event_id=event_id, event_type=event_type, event=event
)
.json(sort_keys=True)
.json(sort_keys=True, exclude_none=True)
.encode()
)
digest = None

View File

@ -16,6 +16,7 @@ from ..onefuzzlib.request import not_ok, ok, parse_request
from ..onefuzzlib.task_event import TaskEvent
from ..onefuzzlib.tasks.config import TaskConfigError, check_config
from ..onefuzzlib.tasks.main import Task
from ..onefuzzlib.user_credentials import parse_jwt_token
def post(req: func.HttpRequest) -> func.HttpResponse:
@ -23,6 +24,10 @@ def post(req: func.HttpRequest) -> func.HttpResponse:
if isinstance(request, Error):
return not_ok(request, context="task create")
user_info = parse_jwt_token(req)
if isinstance(user_info, Error):
return not_ok(user_info, context="task create")
try:
check_config(request)
except TaskConfigError as err:
@ -56,7 +61,7 @@ def post(req: func.HttpRequest) -> func.HttpResponse:
if isinstance(prereq, Error):
return not_ok(prereq, context="task create prerequisite")
task = Task.create(config=request, job_id=request.job_id)
task = Task.create(config=request, job_id=request.job_id, user_info=user_info)
if isinstance(task, Error):
return not_ok(task, context="task create invalid pool")
return ok(task)

View File

@ -7,7 +7,7 @@ from typing import Optional
from uuid import UUID
import json
from onefuzztypes.enums import TaskType, ContainerType, ErrorCode
from onefuzztypes.models import TaskConfig, TaskDetails, TaskContainers, Error
from onefuzztypes.models import TaskConfig, TaskDetails, TaskContainers, Error, UserInfo
from onefuzztypes.webhooks import (
WebhookMessage,
WebhookEventPing,
@ -32,12 +32,23 @@ def main():
examples = {
WebhookEventType.ping: WebhookEventPing(ping_id=UUID(int=0)),
WebhookEventType.task_stopped: WebhookEventTaskStopped(
job_id=UUID(int=0), task_id=UUID(int=0)
job_id=UUID(int=0),
task_id=UUID(int=0),
user_info=UserInfo(
application_id=UUID(int=0),
object_id=UUID(int=0),
upn="example@contoso.com",
),
),
WebhookEventType.task_failed: WebhookEventTaskFailed(
job_id=UUID(int=0),
task_id=UUID(int=0),
error=Error(code=ErrorCode.TASK_FAILED, errors=["example error message"]),
user_info=UserInfo(
application_id=UUID(int=0),
object_id=UUID(int=0),
upn="example@contoso.com",
),
),
WebhookEventType.task_created: WebhookEventTaskCreated(
job_id=UUID(int=0),
@ -58,6 +69,11 @@ def main():
],
tags={},
),
user_info=UserInfo(
application_id=UUID(int=0),
object_id=UUID(int=0),
upn="example@contoso.com",
),
),
}

View File

@ -35,6 +35,12 @@ from .enums import (
from .primitives import Container, PoolName, Region
class UserInfo(BaseModel):
application_id: UUID
object_id: Optional[UUID]
upn: Optional[str]
class EnumModel(BaseModel):
@root_validator(pre=True)
def exactly_one(cls: Any, values: Any) -> Any:
@ -704,3 +710,4 @@ class Task(BaseModel):
end_time: Optional[datetime]
events: Optional[List[TaskEventSummary]]
nodes: Optional[List[NodeAssignment]]
user_info: Optional[UserInfo]

View File

@ -4,25 +4,28 @@ from uuid import UUID, uuid4
from pydantic import AnyHttpUrl, BaseModel, Field
from .enums import WebhookEventType, WebhookMessageState
from .models import Error, TaskConfig
from .models import Error, TaskConfig, UserInfo
from .responses import BaseResponse
class WebhookEventTaskStopped(BaseModel):
job_id: UUID
task_id: UUID
user_info: Optional[UserInfo]
class WebhookEventTaskFailed(BaseModel):
job_id: UUID
task_id: UUID
error: Error
user_info: Optional[UserInfo]
class WebhookEventTaskCreated(BaseModel):
job_id: UUID
task_id: UUID
config: TaskConfig
user_info: Optional[UserInfo]
class WebhookEventPing(BaseResponse):