mirror of
https://github.com/GNS3/gns3-server.git
synced 2025-01-22 04:18:07 +00:00
Merge pull request #1992 from GNS3/secure-websocket-endpoints
Secure websocket endpoints
This commit is contained in:
commit
55e50dae4b
@ -16,8 +16,9 @@
|
||||
|
||||
import re
|
||||
|
||||
from fastapi import Request, Depends, HTTPException, status
|
||||
from fastapi import Request, Query, Depends, HTTPException, WebSocket, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from typing import Optional
|
||||
|
||||
from gns3server import schemas
|
||||
from gns3server.db.repositories.users import UsersRepository
|
||||
@ -76,3 +77,53 @@ async def get_current_active_user(
|
||||
)
|
||||
|
||||
return current_user
|
||||
|
||||
|
||||
async def get_current_active_user_from_websocket(
|
||||
websocket: WebSocket,
|
||||
token: str = Query(...),
|
||||
user_repo: UsersRepository = Depends(get_repository(UsersRepository)),
|
||||
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
|
||||
) -> Optional[schemas.User]:
|
||||
|
||||
await websocket.accept()
|
||||
|
||||
try:
|
||||
username = auth_service.get_username_from_token(token)
|
||||
user = await user_repo.get_user_by_username(username)
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"Could not validate credentials for '{username}'"
|
||||
)
|
||||
|
||||
# Super admin is always authorized
|
||||
if user.is_superadmin:
|
||||
return user
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"'{username}' is not an active user"
|
||||
)
|
||||
|
||||
# remove the prefix (e.g. "/v3") from URL path
|
||||
path = re.sub(r"^/v[0-9]", "", websocket.url.path)
|
||||
|
||||
# there are no HTTP methods for web sockets, assuming "GET"...
|
||||
authorized = await rbac_repo.check_user_is_authorized(user.user_id, "GET", path)
|
||||
if not authorized:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"User is not authorized '{user.user_id}' on '{path}'",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
except HTTPException as e:
|
||||
websocket_error = {"action": "log.error", "event": {"message": f"Could not authenticate while connecting to "
|
||||
f"WebSocket: {e.detail}"}}
|
||||
await websocket.send_json(websocket_error)
|
||||
await websocket.close(code=1008)
|
||||
|
@ -15,13 +15,14 @@
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
from typing import Callable, Type
|
||||
from fastapi import Depends, Request
|
||||
from fastapi import Depends
|
||||
from starlette.requests import HTTPConnection
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from gns3server.db.repositories.base import BaseRepository
|
||||
|
||||
|
||||
async def get_db_session(request: Request) -> AsyncSession:
|
||||
async def get_db_session(request: HTTPConnection) -> AsyncSession:
|
||||
|
||||
session = AsyncSession(request.app.state._db_engine, expire_on_commit=False)
|
||||
try:
|
||||
|
@ -18,14 +18,14 @@
|
||||
API routes for controller notifications.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect, HTTPException
|
||||
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import StreamingResponse
|
||||
from websockets.exceptions import ConnectionClosed, WebSocketException
|
||||
|
||||
from gns3server.services import auth_service
|
||||
from gns3server.controller import Controller
|
||||
from gns3server import schemas
|
||||
|
||||
from .dependencies.authentication import get_current_active_user
|
||||
from .dependencies.authentication import get_current_active_user, get_current_active_user_from_websocket
|
||||
|
||||
import logging
|
||||
|
||||
@ -35,7 +35,7 @@ router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", dependencies=[Depends(get_current_active_user)])
|
||||
async def http_notification() -> StreamingResponse:
|
||||
async def controller_http_notifications() -> StreamingResponse:
|
||||
"""
|
||||
Receive controller notifications about the controller from HTTP stream.
|
||||
"""
|
||||
@ -50,19 +50,16 @@ async def http_notification() -> StreamingResponse:
|
||||
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def notification_ws(websocket: WebSocket, token: str = Query(None)) -> None:
|
||||
async def controller_ws_notifications(
|
||||
websocket: WebSocket,
|
||||
current_user: schemas.User = Depends(get_current_active_user_from_websocket)
|
||||
) -> None:
|
||||
"""
|
||||
Receive project notifications about the controller from WebSocket.
|
||||
"""
|
||||
await websocket.accept()
|
||||
|
||||
if token:
|
||||
try:
|
||||
username = auth_service.get_username_from_token(token)
|
||||
except HTTPException:
|
||||
log.error("Invalid token received")
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
if current_user is None:
|
||||
return
|
||||
|
||||
log.info(f"New client {websocket.client.host}:{websocket.client.port} has connected to controller WebSocket")
|
||||
try:
|
||||
|
@ -51,7 +51,7 @@ from gns3server.db.repositories.rbac import RbacRepository
|
||||
from gns3server.db.repositories.templates import TemplatesRepository
|
||||
from gns3server.services.templates import TemplatesService
|
||||
|
||||
from .dependencies.authentication import get_current_active_user
|
||||
from .dependencies.authentication import get_current_active_user, get_current_active_user_from_websocket
|
||||
from .dependencies.database import get_repository
|
||||
|
||||
responses = {404: {"model": schemas.ErrorMessage, "description": "Could not find project"}}
|
||||
@ -214,7 +214,7 @@ async def load_project(path: str = Body(..., embed=True)) -> schemas.Project:
|
||||
|
||||
|
||||
@router.get("/{project_id}/notifications")
|
||||
async def notification(project_id: UUID) -> StreamingResponse:
|
||||
async def project_http_notifications(project_id: UUID) -> StreamingResponse:
|
||||
"""
|
||||
Receive project notifications about the controller from HTTP stream.
|
||||
"""
|
||||
@ -245,14 +245,20 @@ async def notification(project_id: UUID) -> StreamingResponse:
|
||||
|
||||
|
||||
@router.websocket("/{project_id}/notifications/ws")
|
||||
async def notification_ws(project_id: UUID, websocket: WebSocket) -> None:
|
||||
async def project_ws_notifications(
|
||||
project_id: UUID,
|
||||
websocket: WebSocket,
|
||||
current_user: schemas.User = Depends(get_current_active_user_from_websocket)
|
||||
) -> None:
|
||||
"""
|
||||
Receive project notifications about the controller from WebSocket.
|
||||
"""
|
||||
|
||||
if current_user is None:
|
||||
return
|
||||
|
||||
controller = Controller.instance()
|
||||
project = controller.get_project(str(project_id))
|
||||
await websocket.accept()
|
||||
|
||||
log.info(f"New client has connected to the notification stream for project ID '{project.id}' (WebSocket method)")
|
||||
try:
|
||||
|
@ -57,6 +57,12 @@ def create_default_roles(target, connection, **kw):
|
||||
"path": "/",
|
||||
"action": "ALLOW"
|
||||
},
|
||||
{
|
||||
"description": "Allow to receive controller notifications",
|
||||
"methods": ["GET"],
|
||||
"path": "/notifications",
|
||||
"action": "ALLOW"
|
||||
},
|
||||
{
|
||||
"description": "Allow to create and list projects",
|
||||
"methods": ["GET", "POST"],
|
||||
@ -112,7 +118,7 @@ def add_permissions_to_role(target, connection, **kw):
|
||||
role_id = result.first().role_id
|
||||
|
||||
# add minimum required paths to the "User" role
|
||||
for path in ("/projects", "/templates", "/computes/*", "/symbols/*"):
|
||||
for path in ("/notifications", "/projects", "/templates", "/computes/*", "/symbols/*"):
|
||||
stmt = permissions_table.select().where(permissions_table.c.path == path)
|
||||
result = connection.execute(stmt)
|
||||
permission_id = result.first().permission_id
|
||||
|
@ -92,7 +92,7 @@ class TestPermissionRoutes:
|
||||
|
||||
response = await client.get(app.url_path_for("get_permissions"))
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert len(response.json()) == 10 # 5 default permissions + 5 custom permissions
|
||||
assert len(response.json()) == 11 # 6 default permissions + 5 custom permissions
|
||||
|
||||
async def test_update_permission(self, app: FastAPI, client: AsyncClient, db_session: AsyncSession, project: Project) -> None:
|
||||
|
||||
@ -132,4 +132,4 @@ class TestPermissionRoutes:
|
||||
|
||||
rbac_repo = RbacRepository(db_session)
|
||||
permissions_in_db = await rbac_repo.get_permissions()
|
||||
assert len(permissions_in_db) == 9 # 5 default permissions + 4 custom permissions
|
||||
assert len(permissions_in_db) == 10 # 6 default permissions + 4 custom permissions
|
||||
|
@ -142,7 +142,7 @@ class TestRolesPermissionsRoutes:
|
||||
)
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
permissions = await rbac_repo.get_role_permissions(role_in_db.role_id)
|
||||
assert len(permissions) == 5 # 4 default permissions + 1 custom permission
|
||||
assert len(permissions) == 6 # 5 default permissions + 1 custom permission
|
||||
|
||||
async def test_get_role_permissions(
|
||||
self,
|
||||
@ -160,7 +160,7 @@ class TestRolesPermissionsRoutes:
|
||||
role_id=role_in_db.role_id)
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert len(response.json()) == 5 # 4 default permissions + 1 custom permission
|
||||
assert len(response.json()) == 6 # 5 default permissions + 1 custom permission
|
||||
|
||||
async def test_remove_role_from_group(
|
||||
self,
|
||||
@ -182,4 +182,4 @@ class TestRolesPermissionsRoutes:
|
||||
)
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
permissions = await rbac_repo.get_role_permissions(role_in_db.role_id)
|
||||
assert len(permissions) == 4 # 4 default permissions
|
||||
assert len(permissions) == 5 # 5 default permissions
|
||||
|
Loading…
Reference in New Issue
Block a user