Merge pull request #1992 from GNS3/secure-websocket-endpoints

Secure websocket endpoints
This commit is contained in:
Jeremy Grossmann 2021-11-01 17:10:41 +10:30 committed by GitHub
commit 55e50dae4b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 87 additions and 26 deletions

View File

@ -16,8 +16,9 @@
import re
from fastapi import Request, Depends, HTTPException, status
from fastapi import Request, Query, Depends, HTTPException, WebSocket, status
from fastapi.security import OAuth2PasswordBearer
from typing import Optional
from gns3server import schemas
from gns3server.db.repositories.users import UsersRepository
@ -76,3 +77,53 @@ async def get_current_active_user(
)
return current_user
async def get_current_active_user_from_websocket(
websocket: WebSocket,
token: str = Query(...),
user_repo: UsersRepository = Depends(get_repository(UsersRepository)),
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> Optional[schemas.User]:
await websocket.accept()
try:
username = auth_service.get_username_from_token(token)
user = await user_repo.get_user_by_username(username)
if user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"Could not validate credentials for '{username}'"
)
# Super admin is always authorized
if user.is_superadmin:
return user
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"'{username}' is not an active user"
)
# remove the prefix (e.g. "/v3") from URL path
path = re.sub(r"^/v[0-9]", "", websocket.url.path)
# there are no HTTP methods for web sockets, assuming "GET"...
authorized = await rbac_repo.check_user_is_authorized(user.user_id, "GET", path)
if not authorized:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"User is not authorized '{user.user_id}' on '{path}'",
headers={"WWW-Authenticate": "Bearer"},
)
return user
except HTTPException as e:
websocket_error = {"action": "log.error", "event": {"message": f"Could not authenticate while connecting to "
f"WebSocket: {e.detail}"}}
await websocket.send_json(websocket_error)
await websocket.close(code=1008)

View File

@ -15,13 +15,14 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from typing import Callable, Type
from fastapi import Depends, Request
from fastapi import Depends
from starlette.requests import HTTPConnection
from sqlalchemy.ext.asyncio import AsyncSession
from gns3server.db.repositories.base import BaseRepository
async def get_db_session(request: Request) -> AsyncSession:
async def get_db_session(request: HTTPConnection) -> AsyncSession:
session = AsyncSession(request.app.state._db_engine, expire_on_commit=False)
try:

View File

@ -18,14 +18,14 @@
API routes for controller notifications.
"""
from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect, HTTPException
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
from fastapi.responses import StreamingResponse
from websockets.exceptions import ConnectionClosed, WebSocketException
from gns3server.services import auth_service
from gns3server.controller import Controller
from gns3server import schemas
from .dependencies.authentication import get_current_active_user
from .dependencies.authentication import get_current_active_user, get_current_active_user_from_websocket
import logging
@ -35,7 +35,7 @@ router = APIRouter()
@router.get("", dependencies=[Depends(get_current_active_user)])
async def http_notification() -> StreamingResponse:
async def controller_http_notifications() -> StreamingResponse:
"""
Receive controller notifications about the controller from HTTP stream.
"""
@ -50,19 +50,16 @@ async def http_notification() -> StreamingResponse:
@router.websocket("/ws")
async def notification_ws(websocket: WebSocket, token: str = Query(None)) -> None:
async def controller_ws_notifications(
websocket: WebSocket,
current_user: schemas.User = Depends(get_current_active_user_from_websocket)
) -> None:
"""
Receive project notifications about the controller from WebSocket.
"""
await websocket.accept()
if token:
try:
username = auth_service.get_username_from_token(token)
except HTTPException:
log.error("Invalid token received")
await websocket.close(code=1008)
return
if current_user is None:
return
log.info(f"New client {websocket.client.host}:{websocket.client.port} has connected to controller WebSocket")
try:

View File

@ -51,7 +51,7 @@ from gns3server.db.repositories.rbac import RbacRepository
from gns3server.db.repositories.templates import TemplatesRepository
from gns3server.services.templates import TemplatesService
from .dependencies.authentication import get_current_active_user
from .dependencies.authentication import get_current_active_user, get_current_active_user_from_websocket
from .dependencies.database import get_repository
responses = {404: {"model": schemas.ErrorMessage, "description": "Could not find project"}}
@ -214,7 +214,7 @@ async def load_project(path: str = Body(..., embed=True)) -> schemas.Project:
@router.get("/{project_id}/notifications")
async def notification(project_id: UUID) -> StreamingResponse:
async def project_http_notifications(project_id: UUID) -> StreamingResponse:
"""
Receive project notifications about the controller from HTTP stream.
"""
@ -245,14 +245,20 @@ async def notification(project_id: UUID) -> StreamingResponse:
@router.websocket("/{project_id}/notifications/ws")
async def notification_ws(project_id: UUID, websocket: WebSocket) -> None:
async def project_ws_notifications(
project_id: UUID,
websocket: WebSocket,
current_user: schemas.User = Depends(get_current_active_user_from_websocket)
) -> None:
"""
Receive project notifications about the controller from WebSocket.
"""
if current_user is None:
return
controller = Controller.instance()
project = controller.get_project(str(project_id))
await websocket.accept()
log.info(f"New client has connected to the notification stream for project ID '{project.id}' (WebSocket method)")
try:

View File

@ -57,6 +57,12 @@ def create_default_roles(target, connection, **kw):
"path": "/",
"action": "ALLOW"
},
{
"description": "Allow to receive controller notifications",
"methods": ["GET"],
"path": "/notifications",
"action": "ALLOW"
},
{
"description": "Allow to create and list projects",
"methods": ["GET", "POST"],
@ -112,7 +118,7 @@ def add_permissions_to_role(target, connection, **kw):
role_id = result.first().role_id
# add minimum required paths to the "User" role
for path in ("/projects", "/templates", "/computes/*", "/symbols/*"):
for path in ("/notifications", "/projects", "/templates", "/computes/*", "/symbols/*"):
stmt = permissions_table.select().where(permissions_table.c.path == path)
result = connection.execute(stmt)
permission_id = result.first().permission_id

View File

@ -92,7 +92,7 @@ class TestPermissionRoutes:
response = await client.get(app.url_path_for("get_permissions"))
assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == 10 # 5 default permissions + 5 custom permissions
assert len(response.json()) == 11 # 6 default permissions + 5 custom permissions
async def test_update_permission(self, app: FastAPI, client: AsyncClient, db_session: AsyncSession, project: Project) -> None:
@ -132,4 +132,4 @@ class TestPermissionRoutes:
rbac_repo = RbacRepository(db_session)
permissions_in_db = await rbac_repo.get_permissions()
assert len(permissions_in_db) == 9 # 5 default permissions + 4 custom permissions
assert len(permissions_in_db) == 10 # 6 default permissions + 4 custom permissions

View File

@ -142,7 +142,7 @@ class TestRolesPermissionsRoutes:
)
assert response.status_code == status.HTTP_204_NO_CONTENT
permissions = await rbac_repo.get_role_permissions(role_in_db.role_id)
assert len(permissions) == 5 # 4 default permissions + 1 custom permission
assert len(permissions) == 6 # 5 default permissions + 1 custom permission
async def test_get_role_permissions(
self,
@ -160,7 +160,7 @@ class TestRolesPermissionsRoutes:
role_id=role_in_db.role_id)
)
assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == 5 # 4 default permissions + 1 custom permission
assert len(response.json()) == 6 # 5 default permissions + 1 custom permission
async def test_remove_role_from_group(
self,
@ -182,4 +182,4 @@ class TestRolesPermissionsRoutes:
)
assert response.status_code == status.HTTP_204_NO_CONTENT
permissions = await rbac_repo.get_role_permissions(role_in_db.role_id)
assert len(permissions) == 4 # 4 default permissions
assert len(permissions) == 5 # 5 default permissions