Merge pull request #1906 from GNS3/rbac

RBAC support
This commit is contained in:
Jeremy Grossmann 2021-06-02 23:16:07 -07:00 committed by GitHub
commit 4dd3bc6a98
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 2153 additions and 98 deletions

View File

@ -30,6 +30,8 @@ from . import symbols
from . import templates from . import templates
from . import users from . import users
from . import groups from . import groups
from . import roles
from . import permissions
from .dependencies.authentication import get_current_active_user from .dependencies.authentication import get_current_active_user
@ -46,31 +48,36 @@ router.include_router(
) )
router.include_router( router.include_router(
appliances.router, roles.router,
dependencies=[Depends(get_current_active_user)], dependencies=[Depends(get_current_active_user)],
prefix="/appliances", prefix="/roles",
tags=["Appliances"] tags=["Roles"]
) )
router.include_router( router.include_router(
computes.router, permissions.router,
dependencies=[Depends(get_current_active_user)], dependencies=[Depends(get_current_active_user)],
prefix="/computes", prefix="/permissions",
tags=["Computes"] tags=["Permissions"]
) )
router.include_router( router.include_router(
drawings.router, templates.router,
dependencies=[Depends(get_current_active_user)], dependencies=[Depends(get_current_active_user)],
prefix="/projects/{project_id}/drawings", tags=["Templates"]
tags=["Drawings"]) )
router.include_router( router.include_router(
gns3vm.router, projects.router,
deprecated=True,
dependencies=[Depends(get_current_active_user)], dependencies=[Depends(get_current_active_user)],
prefix="/gns3vm", prefix="/projects",
tags=["GNS3 VM"] tags=["Projects"])
router.include_router(
nodes.router,
dependencies=[Depends(get_current_active_user)],
prefix="/projects/{project_id}/nodes",
tags=["Nodes"]
) )
router.include_router( router.include_router(
@ -81,10 +88,28 @@ router.include_router(
) )
router.include_router( router.include_router(
nodes.router, drawings.router,
dependencies=[Depends(get_current_active_user)], dependencies=[Depends(get_current_active_user)],
prefix="/projects/{project_id}/nodes", prefix="/projects/{project_id}/drawings",
tags=["Nodes"] tags=["Drawings"])
router.include_router(
symbols.router,
dependencies=[Depends(get_current_active_user)],
prefix="/symbols", tags=["Symbols"]
)
router.include_router(
snapshots.router,
dependencies=[Depends(get_current_active_user)],
prefix="/projects/{project_id}/snapshots",
tags=["Snapshots"])
router.include_router(
computes.router,
dependencies=[Depends(get_current_active_user)],
prefix="/computes",
tags=["Computes"]
) )
router.include_router( router.include_router(
@ -94,25 +119,16 @@ router.include_router(
tags=["Notifications"]) tags=["Notifications"])
router.include_router( router.include_router(
projects.router, appliances.router,
dependencies=[Depends(get_current_active_user)], dependencies=[Depends(get_current_active_user)],
prefix="/projects", prefix="/appliances",
tags=["Projects"]) tags=["Appliances"]
router.include_router(
snapshots.router,
dependencies=[Depends(get_current_active_user)],
prefix="/projects/{project_id}/snapshots",
tags=["Snapshots"])
router.include_router(
symbols.router,
dependencies=[Depends(get_current_active_user)],
prefix="/symbols", tags=["Symbols"]
) )
router.include_router( router.include_router(
templates.router, gns3vm.router,
deprecated=True,
dependencies=[Depends(get_current_active_user)], dependencies=[Depends(get_current_active_user)],
tags=["Templates"] prefix="/gns3vm",
tags=["GNS3 VM"]
) )

View File

@ -14,14 +14,15 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import re
from fastapi import Depends, HTTPException, status from fastapi import Request, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from gns3server import schemas from gns3server import schemas
from gns3server.db.repositories.users import UsersRepository from gns3server.db.repositories.users import UsersRepository
from gns3server.db.repositories.rbac import RbacRepository
from gns3server.services import auth_service from gns3server.services import auth_service
from .database import get_repository from .database import get_repository
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/v3/users/login") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/v3/users/login")
@ -42,7 +43,11 @@ async def get_user_from_token(
return user return user
async def get_current_active_user(current_user: schemas.User = Depends(get_user_from_token)) -> schemas.User: async def get_current_active_user(
request: Request,
current_user: schemas.User = Depends(get_user_from_token),
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> schemas.User:
# Super admin is always authorized # Super admin is always authorized
if current_user.is_superadmin: if current_user.is_superadmin:
@ -54,4 +59,24 @@ async def get_current_active_user(current_user: schemas.User = Depends(get_user_
detail="Not an active user", detail="Not an active user",
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
# remove the prefix (e.g. "/v3") from URL path
match = re.search(r"^(/v[0-9]+).*", request.url.path)
if match:
path = request.url.path[len(match.group(1)):]
else:
path = request.url.path
# special case: always authorize access to the "/users/me" endpoint
if path == "/users/me":
return current_user
authorized = await rbac_repo.check_user_is_authorized(current_user.user_id, request.method, path)
if not authorized:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"User is not authorized '{current_user.user_id}' on {request.method} '{path}'",
headers={"WWW-Authenticate": "Bearer"},
)
return current_user return current_user

View File

@ -31,6 +31,7 @@ from gns3server.controller.controller_error import (
) )
from gns3server.db.repositories.users import UsersRepository from gns3server.db.repositories.users import UsersRepository
from gns3server.db.repositories.rbac import RbacRepository
from .dependencies.database import get_repository from .dependencies.database import get_repository
import logging import logging
@ -98,8 +99,8 @@ async def update_user_group(
if not user_group: if not user_group:
raise ControllerNotFoundError(f"User group '{user_group_id}' not found") raise ControllerNotFoundError(f"User group '{user_group_id}' not found")
if not user_group.is_updatable: if user_group.builtin:
raise ControllerForbiddenError(f"User group '{user_group_id}' cannot be updated") raise ControllerForbiddenError(f"Built-in user group '{user_group_id}' cannot be updated")
return await users_repo.update_user_group(user_group_id, user_group_update) return await users_repo.update_user_group(user_group_id, user_group_update)
@ -120,8 +121,8 @@ async def delete_user_group(
if not user_group: if not user_group:
raise ControllerNotFoundError(f"User group '{user_group_id}' not found") raise ControllerNotFoundError(f"User group '{user_group_id}' not found")
if not user_group.is_updatable: if user_group.builtin:
raise ControllerForbiddenError(f"User group '{user_group_id}' cannot be deleted") raise ControllerForbiddenError(f"Built-in user group '{user_group_id}' cannot be deleted")
success = await users_repo.delete_user_group(user_group_id) success = await users_repo.delete_user_group(user_group_id)
if not success: if not success:
@ -182,3 +183,61 @@ async def remove_member_from_group(
user_group = await users_repo.remove_member_from_user_group(user_group_id, user) user_group = await users_repo.remove_member_from_user_group(user_group_id, user)
if not user_group: if not user_group:
raise ControllerNotFoundError(f"User group '{user_group_id}' not found") raise ControllerNotFoundError(f"User group '{user_group_id}' not found")
@router.get("/{user_group_id}/roles", response_model=List[schemas.Role])
async def get_user_group_roles(
user_group_id: UUID,
users_repo: UsersRepository = Depends(get_repository(UsersRepository))
) -> List[schemas.Role]:
"""
Get all user group roles.
"""
return await users_repo.get_user_group_roles(user_group_id)
@router.put(
"/{user_group_id}/roles/{role_id}",
status_code=status.HTTP_204_NO_CONTENT
)
async def add_role_to_group(
user_group_id: UUID,
role_id: UUID,
users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> None:
"""
Add role to an user group.
"""
role = await rbac_repo.get_role(role_id)
if not role:
raise ControllerNotFoundError(f"Role '{role_id}' not found")
user_group = await users_repo.add_role_to_user_group(user_group_id, role)
if not user_group:
raise ControllerNotFoundError(f"User group '{user_group_id}' not found")
@router.delete(
"/{user_group_id}/roles/{role_id}",
status_code=status.HTTP_204_NO_CONTENT
)
async def remove_role_from_group(
user_group_id: UUID,
role_id: UUID,
users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> None:
"""
Remove role from an user group.
"""
role = await rbac_repo.get_role(role_id)
if not role:
raise ControllerNotFoundError(f"Role '{role_id}' not found")
user_group = await users_repo.remove_role_from_user_group(user_group_id, role)
if not user_group:
raise ControllerNotFoundError(f"User group '{user_group_id}' not found")

View File

@ -0,0 +1,117 @@
#!/usr/bin/env python
#
# Copyright (C) 2021 GNS3 Technologies Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
API routes for permissions.
"""
from fastapi import APIRouter, Depends, status
from uuid import UUID
from typing import List
from gns3server import schemas
from gns3server.controller.controller_error import (
ControllerBadRequestError,
ControllerNotFoundError,
ControllerForbiddenError,
)
from gns3server.db.repositories.rbac import RbacRepository
from .dependencies.database import get_repository
import logging
log = logging.getLogger(__name__)
router = APIRouter()
@router.get("", response_model=List[schemas.Permission])
async def get_permissions(
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> List[schemas.Permission]:
"""
Get all permissions.
"""
return await rbac_repo.get_permissions()
@router.post("", response_model=schemas.Permission, status_code=status.HTTP_201_CREATED)
async def create_permission(
permission_create: schemas.PermissionCreate,
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> schemas.Permission:
"""
Create a new permission.
"""
if await rbac_repo.check_permission_exists(permission_create):
raise ControllerBadRequestError(f"Permission '{permission_create.methods} {permission_create.path} "
f"{permission_create.action}' already exists")
return await rbac_repo.create_permission(permission_create)
@router.get("/{permission_id}", response_model=schemas.Permission)
async def get_permission(
permission_id: UUID,
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)),
) -> schemas.Permission:
"""
Get a permission.
"""
permission = await rbac_repo.get_permission(permission_id)
if not permission:
raise ControllerNotFoundError(f"Permission '{permission_id}' not found")
return permission
@router.put("/{permission_id}", response_model=schemas.Permission)
async def update_permission(
permission_id: UUID,
permission_update: schemas.PermissionUpdate,
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> schemas.Permission:
"""
Update a permission.
"""
permission = await rbac_repo.get_permission(permission_id)
if not permission:
raise ControllerNotFoundError(f"Permission '{permission_id}' not found")
return await rbac_repo.update_permission(permission_id, permission_update)
@router.delete("/{permission_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_permission(
permission_id: UUID,
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)),
) -> None:
"""
Delete a permission.
"""
permission = await rbac_repo.get_permission(permission_id)
if not permission:
raise ControllerNotFoundError(f"Permission '{permission_id}' not found")
success = await rbac_repo.delete_permission(permission_id)
if not success:
raise ControllerNotFoundError(f"Permission '{permission_id}' could not be deleted")

View File

@ -47,6 +47,10 @@ from gns3server.controller.export_project import export_project as export_contro
from gns3server.utils.asyncio import aiozipstream from gns3server.utils.asyncio import aiozipstream
from gns3server.utils.path import is_safe_path from gns3server.utils.path import is_safe_path
from gns3server.config import Config from gns3server.config import Config
from gns3server.db.repositories.rbac import RbacRepository
from .dependencies.authentication import get_current_active_user
from .dependencies.database import get_repository
responses = {404: {"model": schemas.ErrorMessage, "description": "Could not find project"}} responses = {404: {"model": schemas.ErrorMessage, "description": "Could not find project"}}
@ -66,13 +70,25 @@ CHUNK_SIZE = 1024 * 8 # 8KB
@router.get("", response_model=List[schemas.Project], response_model_exclude_unset=True) @router.get("", response_model=List[schemas.Project], response_model_exclude_unset=True)
def get_projects() -> List[schemas.Project]: async def get_projects(
current_user: schemas.User = Depends(get_current_active_user),
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> List[schemas.Project]:
""" """
Return all projects. Return all projects.
""" """
controller = Controller.instance() controller = Controller.instance()
return [p.asdict() for p in controller.projects.values()] if current_user.is_superadmin:
return [p.asdict() for p in controller.projects.values()]
else:
user_projects = []
for project in controller.projects.values():
authorized = await rbac_repo.check_user_is_authorized(
current_user.user_id, "GET", f"/projects/{project.id}")
if authorized:
user_projects.append(project.asdict())
return user_projects
@router.post( @router.post(
@ -82,13 +98,18 @@ def get_projects() -> List[schemas.Project]:
response_model_exclude_unset=True, response_model_exclude_unset=True,
responses={409: {"model": schemas.ErrorMessage, "description": "Could not create project"}}, responses={409: {"model": schemas.ErrorMessage, "description": "Could not create project"}},
) )
async def create_project(project_data: schemas.ProjectCreate) -> schemas.Project: async def create_project(
project_data: schemas.ProjectCreate,
current_user: schemas.User = Depends(get_current_active_user),
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> schemas.Project:
""" """
Create a new project. Create a new project.
""" """
controller = Controller.instance() controller = Controller.instance()
project = await controller.add_project(**jsonable_encoder(project_data, exclude_unset=True)) project = await controller.add_project(**jsonable_encoder(project_data, exclude_unset=True))
await rbac_repo.add_permission_to_user_with_path(current_user.user_id, f"/projects/{project.id}/*")
return project.asdict() return project.asdict()
@ -115,7 +136,10 @@ async def update_project(
@router.delete("/{project_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/{project_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_project(project: Project = Depends(dep_project)) -> None: async def delete_project(
project: Project = Depends(dep_project),
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> None:
""" """
Delete a project. Delete a project.
""" """
@ -123,6 +147,7 @@ async def delete_project(project: Project = Depends(dep_project)) -> None:
controller = Controller.instance() controller = Controller.instance()
await project.delete() await project.delete()
controller.remove_project(project) controller.remove_project(project)
await rbac_repo.delete_all_permissions_with_path(f"/projects/{project.id}")
@router.get("/{project_id}/stats") @router.get("/{project_id}/stats")
@ -344,7 +369,9 @@ async def import_project(
) )
async def duplicate_project( async def duplicate_project(
project_data: schemas.ProjectDuplicate, project_data: schemas.ProjectDuplicate,
project: Project = Depends(dep_project) project: Project = Depends(dep_project),
current_user: schemas.User = Depends(get_current_active_user),
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> schemas.Project: ) -> schemas.Project:
""" """
Duplicate a project. Duplicate a project.
@ -361,6 +388,7 @@ async def duplicate_project(
new_project = await project.duplicate( new_project = await project.duplicate(
name=project_data.name, location=location, reset_mac_addresses=reset_mac_addresses name=project_data.name, location=location, reset_mac_addresses=reset_mac_addresses
) )
await rbac_repo.add_permission_to_user_with_path(current_user.user_id, f"/projects/{new_project.id}/*")
return new_project.asdict() return new_project.asdict()

View File

@ -0,0 +1,178 @@
#!/usr/bin/env python
#
# Copyright (C) 2021 GNS3 Technologies Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
API routes for roles.
"""
from fastapi import APIRouter, Depends, status
from uuid import UUID
from typing import List
from gns3server import schemas
from gns3server.controller.controller_error import (
ControllerBadRequestError,
ControllerNotFoundError,
ControllerForbiddenError,
)
from gns3server.db.repositories.rbac import RbacRepository
from .dependencies.database import get_repository
import logging
log = logging.getLogger(__name__)
router = APIRouter()
@router.get("", response_model=List[schemas.Role])
async def get_roles(
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> List[schemas.Role]:
"""
Get all roles.
"""
return await rbac_repo.get_roles()
@router.post("", response_model=schemas.Role, status_code=status.HTTP_201_CREATED)
async def create_role(
role_create: schemas.RoleCreate,
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> schemas.Role:
"""
Create a new role.
"""
if await rbac_repo.get_role_by_name(role_create.name):
raise ControllerBadRequestError(f"Role '{role_create.name}' already exists")
return await rbac_repo.create_role(role_create)
@router.get("/{role_id}", response_model=schemas.Role)
async def get_role(
role_id: UUID,
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)),
) -> schemas.Role:
"""
Get a role.
"""
role = await rbac_repo.get_role(role_id)
if not role:
raise ControllerNotFoundError(f"Role '{role_id}' not found")
return role
@router.put("/{role_id}", response_model=schemas.Role)
async def update_role(
role_id: UUID,
role_update: schemas.RoleUpdate,
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> schemas.Role:
"""
Update a role.
"""
role = await rbac_repo.get_role(role_id)
if not role:
raise ControllerNotFoundError(f"Role '{role_id}' not found")
if role.builtin:
raise ControllerForbiddenError(f"Built-in role '{role_id}' cannot be updated")
return await rbac_repo.update_role(role_id, role_update)
@router.delete("/{role_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_role(
role_id: UUID,
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)),
) -> None:
"""
Delete a role.
"""
role = await rbac_repo.get_role(role_id)
if not role:
raise ControllerNotFoundError(f"Role '{role_id}' not found")
if role.builtin:
raise ControllerForbiddenError(f"Built-in role '{role_id}' cannot be deleted")
success = await rbac_repo.delete_role(role_id)
if not success:
raise ControllerNotFoundError(f"Role '{role_id}' could not be deleted")
@router.get("/{role_id}/permissions", response_model=List[schemas.Permission])
async def get_role_permissions(
role_id: UUID,
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> List[schemas.Permission]:
"""
Get all role permissions.
"""
return await rbac_repo.get_role_permissions(role_id)
@router.put(
"/{role_id}/permissions/{permission_id}",
status_code=status.HTTP_204_NO_CONTENT
)
async def add_permission_to_role(
role_id: UUID,
permission_id: UUID,
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> None:
"""
Add a permission to a role.
"""
permission = await rbac_repo.get_permission(permission_id)
if not permission:
raise ControllerNotFoundError(f"Permission '{permission_id}' not found")
role = await rbac_repo.add_permission_to_role(role_id, permission)
if not role:
raise ControllerNotFoundError(f"Role '{role_id}' not found")
@router.delete(
"/{role_id}/permissions/{permission_id}",
status_code=status.HTTP_204_NO_CONTENT
)
async def remove_permission_from_role(
role_id: UUID,
permission_id: UUID,
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)),
) -> None:
"""
Remove member from an user group.
"""
permission = await rbac_repo.get_permission(permission_id)
if not permission:
raise ControllerNotFoundError(f"Permission '{permission_id}' not found")
role = await rbac_repo.remove_permission_from_role(role_id, permission)
if not role:
raise ControllerNotFoundError(f"Role '{role_id}' not found")

View File

@ -33,6 +33,9 @@ from gns3server import schemas
from gns3server.controller import Controller from gns3server.controller import Controller
from gns3server.db.repositories.templates import TemplatesRepository from gns3server.db.repositories.templates import TemplatesRepository
from gns3server.services.templates import TemplatesService from gns3server.services.templates import TemplatesService
from gns3server.db.repositories.rbac import RbacRepository
from .dependencies.authentication import get_current_active_user
from .dependencies.database import get_repository from .dependencies.database import get_repository
responses = {404: {"model": schemas.ErrorMessage, "description": "Could not find template"}} responses = {404: {"model": schemas.ErrorMessage, "description": "Could not find template"}}
@ -44,12 +47,17 @@ router = APIRouter(responses=responses)
async def create_template( async def create_template(
template_create: schemas.TemplateCreate, template_create: schemas.TemplateCreate,
templates_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)), templates_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)),
current_user: schemas.User = Depends(get_current_active_user),
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> schemas.Template: ) -> schemas.Template:
""" """
Create a new template. Create a new template.
""" """
return await TemplatesService(templates_repo).create_template(template_create) template = await TemplatesService(templates_repo).create_template(template_create)
template_id = template.get("template_id")
await rbac_repo.add_permission_to_user_with_path(current_user.user_id, f"/templates/{template_id}/*")
return template
@router.get("/templates/{template_id}", response_model=schemas.Template, response_model_exclude_unset=True) @router.get("/templates/{template_id}", response_model=schemas.Template, response_model_exclude_unset=True)
@ -92,35 +100,58 @@ async def update_template(
status_code=status.HTTP_204_NO_CONTENT, status_code=status.HTTP_204_NO_CONTENT,
) )
async def delete_template( async def delete_template(
template_id: UUID, templates_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)) template_id: UUID,
templates_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)),
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> None: ) -> None:
""" """
Delete a template. Delete a template.
""" """
await TemplatesService(templates_repo).delete_template(template_id) await TemplatesService(templates_repo).delete_template(template_id)
await rbac_repo.delete_all_permissions_with_path(f"/templates/{template_id}")
@router.get("/templates", response_model=List[schemas.Template], response_model_exclude_unset=True) @router.get("/templates", response_model=List[schemas.Template], response_model_exclude_unset=True)
async def get_templates( async def get_templates(
templates_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)), templates_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)),
current_user: schemas.User = Depends(get_current_active_user),
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> List[schemas.Template]: ) -> List[schemas.Template]:
""" """
Return all templates. Return all templates.
""" """
return await TemplatesService(templates_repo).get_templates() templates = await TemplatesService(templates_repo).get_templates()
if current_user.is_superadmin:
return templates
else:
user_templates = []
for template in templates:
if template.get("builtin") is True:
user_templates.append(template)
continue
template_id = template.get("template_id")
authorized = await rbac_repo.check_user_is_authorized(
current_user.user_id, "GET", f"/templates/{template_id}")
if authorized:
user_templates.append(template)
return user_templates
@router.post("/templates/{template_id}/duplicate", response_model=schemas.Template, status_code=status.HTTP_201_CREATED) @router.post("/templates/{template_id}/duplicate", response_model=schemas.Template, status_code=status.HTTP_201_CREATED)
async def duplicate_template( async def duplicate_template(
template_id: UUID, templates_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)) template_id: UUID, templates_repo: TemplatesRepository = Depends(get_repository(TemplatesRepository)),
current_user: schemas.User = Depends(get_current_active_user),
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> schemas.Template: ) -> schemas.Template:
""" """
Duplicate a template. Duplicate a template.
""" """
return await TemplatesService(templates_repo).duplicate_template(template_id) template = await TemplatesService(templates_repo).duplicate_template(template_id)
await rbac_repo.add_permission_to_user_with_path(current_user.user_id, f"/templates/{template_id}/*")
return template
@router.post( @router.post(

View File

@ -32,6 +32,7 @@ from gns3server.controller.controller_error import (
) )
from gns3server.db.repositories.users import UsersRepository from gns3server.db.repositories.users import UsersRepository
from gns3server.db.repositories.rbac import RbacRepository
from gns3server.services import auth_service from gns3server.services import auth_service
from .dependencies.authentication import get_current_active_user from .dependencies.authentication import get_current_active_user
@ -88,6 +89,24 @@ async def authenticate(
return token return token
@router.get("/me", response_model=schemas.User)
async def get_logged_in_user(current_user: schemas.User = Depends(get_current_active_user)) -> schemas.User:
"""
Get the current active user.
"""
return current_user
@router.get("/me", response_model=schemas.User)
async def get_logged_in_user(current_user: schemas.User = Depends(get_current_active_user)) -> schemas.User:
"""
Get the current active user.
"""
return current_user
@router.get("", response_model=List[schemas.User], dependencies=[Depends(get_current_active_user)]) @router.get("", response_model=List[schemas.User], dependencies=[Depends(get_current_active_user)])
async def get_users( async def get_users(
users_repo: UsersRepository = Depends(get_repository(UsersRepository)) users_repo: UsersRepository = Depends(get_repository(UsersRepository))
@ -178,15 +197,6 @@ async def delete_user(
raise ControllerNotFoundError(f"User '{user_id}' could not be deleted") raise ControllerNotFoundError(f"User '{user_id}' could not be deleted")
@router.get("/me/", response_model=schemas.User)
async def get_current_active_user(current_user: schemas.User = Depends(get_current_active_user)) -> schemas.User:
"""
Get the current active user.
"""
return current_user
@router.get( @router.get(
"/{user_id}/groups", "/{user_id}/groups",
dependencies=[Depends(get_current_active_user)], dependencies=[Depends(get_current_active_user)],
@ -201,3 +211,65 @@ async def get_user_memberships(
""" """
return await users_repo.get_user_memberships(user_id) return await users_repo.get_user_memberships(user_id)
@router.get(
"/{user_id}/permissions",
dependencies=[Depends(get_current_active_user)],
response_model=List[schemas.Permission]
)
async def get_user_permissions(
user_id: UUID,
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> List[schemas.Permission]:
"""
Get user permissions.
"""
return await rbac_repo.get_user_permissions(user_id)
@router.put(
"/{user_id}/permissions/{permission_id}",
dependencies=[Depends(get_current_active_user)],
status_code=status.HTTP_204_NO_CONTENT
)
async def add_permission_to_user(
user_id: UUID,
permission_id: UUID,
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository))
) -> None:
"""
Add a permission to an user.
"""
permission = await rbac_repo.get_permission(permission_id)
if not permission:
raise ControllerNotFoundError(f"Permission '{permission_id}' not found")
user = await rbac_repo.add_permission_to_user(user_id, permission)
if not user:
raise ControllerNotFoundError(f"User '{user_id}' not found")
@router.delete(
"/{user_id}/permissions/{permission_id}",
dependencies=[Depends(get_current_active_user)],
status_code=status.HTTP_204_NO_CONTENT
)
async def remove_permission_from_user(
user_id: UUID,
permission_id: UUID,
rbac_repo: RbacRepository = Depends(get_repository(RbacRepository)),
) -> None:
"""
Remove permission from an user.
"""
permission = await rbac_repo.get_permission(permission_id)
if not permission:
raise ControllerNotFoundError(f"Permission '{permission_id}' not found")
user = await rbac_repo.remove_permission_from_user(user_id, permission)
if not user:
raise ControllerNotFoundError(f"User '{user_id}' not found")

View File

@ -17,6 +17,8 @@
from .base import Base from .base import Base
from .users import User, UserGroup from .users import User, UserGroup
from .roles import Role
from .permissions import Permission
from .computes import Compute from .computes import Compute
from .templates import ( from .templates import (
Template, Template,

View File

@ -19,7 +19,7 @@ import uuid
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from sqlalchemy import Column, DateTime, func, inspect from sqlalchemy import Column, DateTime, func, inspect
from sqlalchemy.types import TypeDecorator, CHAR from sqlalchemy.types import TypeDecorator, CHAR, VARCHAR
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.ext.declarative import as_declarative from sqlalchemy.ext.declarative import as_declarative
@ -72,6 +72,37 @@ class GUID(TypeDecorator):
return value return value
class ListException(Exception):
pass
class ListType(TypeDecorator):
"""
Save/restore a Python list to/from a database column.
"""
impl = VARCHAR
cache_ok = True
def __init__(self, separator=',', *args, **kwargs):
self._separator = separator
super().__init__(*args, **kwargs)
def process_bind_param(self, value, dialect):
if value is not None:
if any(self._separator in str(item) for item in value):
raise ListException(f"List values cannot contain '{self._separator}'"
f"Please use a different separator.")
return self._separator.join(map(str, value))
def process_result_value(self, value, dialect):
if value is None:
return []
else:
return list(map(str, value.split(self._separator)))
class BaseTable(Base): class BaseTable(Base):
__abstract__ = True __abstract__ = True
@ -79,6 +110,8 @@ class BaseTable(Base):
created_at = Column(DateTime, server_default=func.current_timestamp()) created_at = Column(DateTime, server_default=func.current_timestamp())
updated_at = Column(DateTime, server_default=func.current_timestamp(), onupdate=func.current_timestamp()) updated_at = Column(DateTime, server_default=func.current_timestamp(), onupdate=func.current_timestamp())
__mapper_args__ = {"eager_defaults": True}
def generate_uuid(): def generate_uuid():
return str(uuid.uuid4()) return str(uuid.uuid4())

View File

@ -0,0 +1,122 @@
#!/usr/bin/env python
#
# Copyright (C) 2021 GNS3 Technologies Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from sqlalchemy import Table, Column, String, ForeignKey, event
from sqlalchemy.orm import relationship
from .base import Base, BaseTable, generate_uuid, GUID, ListType
import logging
log = logging.getLogger(__name__)
permission_role_link = Table(
"permissions_roles_link",
Base.metadata,
Column("permission_id", GUID, ForeignKey("permissions.permission_id", ondelete="CASCADE")),
Column("role_id", GUID, ForeignKey("roles.role_id", ondelete="CASCADE"))
)
class Permission(BaseTable):
__tablename__ = "permissions"
permission_id = Column(GUID, primary_key=True, default=generate_uuid)
description = Column(String)
methods = Column(ListType)
path = Column(String)
action = Column(String)
user_id = Column(GUID, ForeignKey('users.user_id', ondelete="CASCADE"))
roles = relationship("Role", secondary=permission_role_link, back_populates="permissions")
@event.listens_for(Permission.__table__, 'after_create')
def create_default_roles(target, connection, **kw):
default_permissions = [
{
"description": "Allow access to all endpoints",
"methods": ["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH"],
"path": "/",
"action": "ALLOW"
},
{
"description": "Allow to create and list projects",
"methods": ["GET", "HEAD", "POST"],
"path": "/projects",
"action": "ALLOW"
},
{
"description": "Allow to create and list templates",
"methods": ["GET", "HEAD", "POST"],
"path": "/templates",
"action": "ALLOW"
},
{
"description": "Allow to list computes",
"methods": ["GET"],
"path": "/computes/*",
"action": "ALLOW"
},
{
"description": "Allow access to all symbol endpoints",
"methods": ["GET", "HEAD", "POST"],
"path": "/symbols/*",
"action": "ALLOW"
},
]
stmt = target.insert().values(default_permissions)
connection.execute(stmt)
connection.commit()
log.debug("The default permissions have been created in the database")
@event.listens_for(permission_role_link, 'after_create')
def add_permissions_to_role(target, connection, **kw):
from .roles import Role
roles_table = Role.__table__
stmt = roles_table.select().where(roles_table.c.name == "Administrator")
result = connection.execute(stmt)
role_id = result.first().role_id
permissions_table = Permission.__table__
stmt = permissions_table.select().where(permissions_table.c.path == "/")
result = connection.execute(stmt)
permission_id = result.first().permission_id
# add root path to the "Administrator" role
stmt = target.insert().values(permission_id=permission_id, role_id=role_id)
connection.execute(stmt)
stmt = roles_table.select().where(roles_table.c.name == "User")
result = connection.execute(stmt)
role_id = result.first().role_id
# add minimum required paths to the "User" role
for path in ("/projects", "/templates", "/computes/*", "/symbols/*"):
stmt = permissions_table.select().where(permissions_table.c.path == path)
result = connection.execute(stmt)
permission_id = result.first().permission_id
stmt = target.insert().values(permission_id=permission_id, role_id=role_id)
connection.execute(stmt)
connection.commit()

View File

@ -0,0 +1,81 @@
#!/usr/bin/env python
#
# Copyright (C) 2021 GNS3 Technologies Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from sqlalchemy import Table, Column, String, Boolean, ForeignKey, event
from sqlalchemy.orm import relationship
from .base import Base, BaseTable, generate_uuid, GUID
from .permissions import permission_role_link
import logging
log = logging.getLogger(__name__)
role_group_link = Table(
"roles_groups_link",
Base.metadata,
Column("role_id", GUID, ForeignKey("roles.role_id", ondelete="CASCADE")),
Column("user_group_id", GUID, ForeignKey("user_groups.user_group_id", ondelete="CASCADE"))
)
class Role(BaseTable):
__tablename__ = "roles"
role_id = Column(GUID, primary_key=True, default=generate_uuid)
name = Column(String)
description = Column(String)
builtin = Column(Boolean, default=False)
permissions = relationship("Permission", secondary=permission_role_link, back_populates="roles")
groups = relationship("UserGroup", secondary=role_group_link, back_populates="roles")
@event.listens_for(Role.__table__, 'after_create')
def create_default_roles(target, connection, **kw):
default_roles = [
{"name": "Administrator", "description": "Administrator role", "builtin": True},
{"name": "User", "description": "User role", "builtin": True},
]
stmt = target.insert().values(default_roles)
connection.execute(stmt)
connection.commit()
log.debug("The default roles have been created in the database")
@event.listens_for(role_group_link, 'after_create')
def add_admin_to_group(target, connection, **kw):
from .users import UserGroup
user_groups_table = UserGroup.__table__
roles_table = Role.__table__
# Add roles to built-in user groups
groups_to_roles = {"Administrators": "Administrator", "Users": "User"}
for user_group, role in groups_to_roles.items():
stmt = user_groups_table.select().where(user_groups_table.c.name == user_group)
result = connection.execute(stmt)
user_group_id = result.first().user_group_id
stmt = roles_table.select().where(roles_table.c.name == role)
result = connection.execute(stmt)
role_id = result.first().role_id
stmt = target.insert().values(role_id=role_id, user_group_id=user_group_id)
connection.execute(stmt)
connection.commit()

View File

@ -19,6 +19,8 @@ from sqlalchemy import Table, Boolean, Column, String, ForeignKey, event
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from .base import Base, BaseTable, generate_uuid, GUID from .base import Base, BaseTable, generate_uuid, GUID
from .roles import role_group_link
from gns3server.config import Config from gns3server.config import Config
from gns3server.services import auth_service from gns3server.services import auth_service
@ -26,11 +28,11 @@ import logging
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
users_group_members = Table( user_group_link = Table(
"users_group_members", "users_groups_link",
Base.metadata, Base.metadata,
Column("user_id", GUID, ForeignKey("users.user_id", ondelete="CASCADE")), Column("user_id", GUID, ForeignKey("users.user_id", ondelete="CASCADE")),
Column("user_group_id", GUID, ForeignKey("users_group.user_group_id", ondelete="CASCADE")) Column("user_group_id", GUID, ForeignKey("user_groups.user_group_id", ondelete="CASCADE"))
) )
@ -45,7 +47,8 @@ class User(BaseTable):
hashed_password = Column(String) hashed_password = Column(String)
is_active = Column(Boolean, default=True) is_active = Column(Boolean, default=True)
is_superadmin = Column(Boolean, default=False) is_superadmin = Column(Boolean, default=False)
groups = relationship("UserGroup", secondary=users_group_members, back_populates="users") groups = relationship("UserGroup", secondary=user_group_link, back_populates="users")
permissions = relationship("Permission")
@event.listens_for(User.__table__, 'after_create') @event.listens_for(User.__table__, 'after_create')
@ -63,47 +66,47 @@ def create_default_super_admin(target, connection, **kw):
) )
connection.execute(stmt) connection.execute(stmt)
connection.commit() connection.commit()
log.info("The default super admin account has been created in the database") log.debug("The default super admin account has been created in the database")
class UserGroup(BaseTable): class UserGroup(BaseTable):
__tablename__ = "users_group" __tablename__ = "user_groups"
user_group_id = Column(GUID, primary_key=True, default=generate_uuid) user_group_id = Column(GUID, primary_key=True, default=generate_uuid)
name = Column(String, unique=True, index=True) name = Column(String, unique=True, index=True)
is_updatable = Column(Boolean, default=True) builtin = Column(Boolean, default=False)
users = relationship("User", secondary=users_group_members, back_populates="groups") users = relationship("User", secondary=user_group_link, back_populates="groups")
roles = relationship("Role", secondary=role_group_link, back_populates="groups")
@event.listens_for(UserGroup.__table__, 'after_create') @event.listens_for(UserGroup.__table__, 'after_create')
def create_default_user_groups(target, connection, **kw): def create_default_user_groups(target, connection, **kw):
default_groups = [ default_groups = [
{"name": "Administrators", "is_updatable": False}, {"name": "Administrators", "builtin": True},
{"name": "Editors", "is_updatable": False}, {"name": "Users", "builtin": True}
{"name": "Users", "is_updatable": False}
] ]
stmt = target.insert().values(default_groups) stmt = target.insert().values(default_groups)
connection.execute(stmt) connection.execute(stmt)
connection.commit() connection.commit()
log.info("The default user groups have been created in the database") log.debug("The default user groups have been created in the database")
@event.listens_for(users_group_members, 'after_create') # @event.listens_for(user_group_link, 'after_create')
def add_admin_to_group(target, connection, **kw): # def add_admin_to_group(target, connection, **kw):
#
users_group_table = UserGroup.__table__ # user_groups_table = UserGroup.__table__
stmt = users_group_table.select().where(users_group_table.c.name == "Administrators") # stmt = user_groups_table.select().where(user_groups_table.c.name == "Administrators")
result = connection.execute(stmt) # result = connection.execute(stmt)
user_group_id = result.first().user_group_id # user_group_id = result.first().user_group_id
#
users_table = User.__table__ # users_table = User.__table__
stmt = users_table.select().where(users_table.c.is_superadmin.is_(True)) # stmt = users_table.select().where(users_table.c.is_superadmin.is_(True))
result = connection.execute(stmt) # result = connection.execute(stmt)
user_id = result.first().user_id # user_id = result.first().user_id
#
stmt = target.insert().values(user_id=user_id, user_group_id=user_group_id) # stmt = target.insert().values(user_id=user_id, user_group_id=user_group_id)
connection.execute(stmt) # connection.execute(stmt)
connection.commit() # connection.commit()

View File

@ -0,0 +1,409 @@
#!/usr/bin/env python
#
# Copyright (C) 2020 GNS3 Technologies Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from uuid import UUID
from typing import Optional, List, Union
from sqlalchemy import select, update, delete
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from .base import BaseRepository
import gns3server.db.models as models
from gns3server.schemas.controller.rbac import HTTPMethods, PermissionAction
from gns3server import schemas
import logging
log = logging.getLogger(__name__)
class RbacRepository(BaseRepository):
def __init__(self, db_session: AsyncSession) -> None:
super().__init__(db_session)
async def get_role(self, role_id: UUID) -> Optional[models.Role]:
"""
Get a role by its ID.
"""
query = select(models.Role).\
options(selectinload(models.Role.permissions)).\
where(models.Role.role_id == role_id)
result = await self._db_session.execute(query)
return result.scalars().first()
async def get_role_by_name(self, name: str) -> Optional[models.Role]:
"""
Get a role by its name.
"""
query = select(models.Role).\
options(selectinload(models.Role.permissions)).\
where(models.Role.name == name)
#query = select(models.Role).where(models.Role.name == name)
result = await self._db_session.execute(query)
return result.scalars().first()
async def get_roles(self) -> List[models.Role]:
"""
Get all roles.
"""
query = select(models.Role).options(selectinload(models.Role.permissions))
result = await self._db_session.execute(query)
return result.scalars().all()
async def create_role(self, role_create: schemas.RoleCreate) -> models.Role:
"""
Create a new role.
"""
db_role = models.Role(
name=role_create.name,
description=role_create.description,
)
self._db_session.add(db_role)
await self._db_session.commit()
#await self._db_session.refresh(db_role)
return await self.get_role(db_role.role_id)
async def update_role(
self,
role_id: UUID,
role_update: schemas.RoleUpdate
) -> Optional[models.Role]:
"""
Update a role.
"""
update_values = role_update.dict(exclude_unset=True)
query = update(models.Role).\
where(models.Role.role_id == role_id).\
values(update_values)
await self._db_session.execute(query)
await self._db_session.commit()
role_db = await self.get_role(role_id)
if role_db:
await self._db_session.refresh(role_db) # force refresh of updated_at value
return role_db
async def delete_role(self, role_id: UUID) -> bool:
"""
Delete a role.
"""
query = delete(models.Role).where(models.Role.role_id == role_id)
result = await self._db_session.execute(query)
await self._db_session.commit()
return result.rowcount > 0
async def add_permission_to_role(
self,
role_id: UUID,
permission: models.Permission
) -> Union[None, models.Role]:
"""
Add a permission to a role.
"""
query = select(models.Role).\
options(selectinload(models.Role.permissions)).\
where(models.Role.role_id == role_id)
result = await self._db_session.execute(query)
role_db = result.scalars().first()
if not role_db:
return None
role_db.permissions.append(permission)
await self._db_session.commit()
await self._db_session.refresh(role_db)
return role_db
async def remove_permission_from_role(
self,
role_id: UUID,
permission: models.Permission
) -> Union[None, models.Role]:
"""
Remove a permission from a role.
"""
query = select(models.Role).\
options(selectinload(models.Role.permissions)).\
where(models.Role.role_id == role_id)
result = await self._db_session.execute(query)
role_db = result.scalars().first()
if not role_db:
return None
role_db.permissions.remove(permission)
await self._db_session.commit()
await self._db_session.refresh(role_db)
return role_db
async def get_role_permissions(self, role_id: UUID) -> List[models.Permission]:
"""
Get all the role permissions.
"""
query = select(models.Permission).\
join(models.Permission.roles).\
filter(models.Role.role_id == role_id)
result = await self._db_session.execute(query)
return result.scalars().all()
async def get_permission(self, permission_id: UUID) -> Optional[models.Permission]:
"""
Get a permission by its ID.
"""
query = select(models.Permission).where(models.Permission.permission_id == permission_id)
result = await self._db_session.execute(query)
return result.scalars().first()
async def get_permission_by_path(self, path: str) -> Optional[models.Permission]:
"""
Get a permission by its path.
"""
query = select(models.Permission).where(models.Permission.path == path)
result = await self._db_session.execute(query)
return result.scalars().first()
async def get_permissions(self) -> List[models.Permission]:
"""
Get all permissions.
"""
query = select(models.Permission)
result = await self._db_session.execute(query)
return result.scalars().all()
async def check_permission_exists(self, permission_create: schemas.PermissionCreate) -> bool:
"""
Check if a permission exists.
"""
query = select(models.Permission).\
where(models.Permission.methods == permission_create.methods,
models.Permission.path == permission_create.path,
models.Permission.action == permission_create.action)
result = await self._db_session.execute(query)
return result.scalars().first() is not None
async def create_permission(self, permission_create: schemas.PermissionCreate) -> models.Permission:
"""
Create a new permission.
"""
db_permission = models.Permission(
description=permission_create.description,
methods=permission_create.methods,
path=permission_create.path,
action=permission_create.action,
)
self._db_session.add(db_permission)
await self._db_session.commit()
await self._db_session.refresh(db_permission)
return db_permission
async def update_permission(
self,
permission_id: UUID,
permission_update: schemas.PermissionUpdate
) -> Optional[models.Permission]:
"""
Update a permission.
"""
update_values = permission_update.dict(exclude_unset=True)
query = update(models.Permission).\
where(models.Permission.permission_id == permission_id).\
values(update_values)
await self._db_session.execute(query)
await self._db_session.commit()
permission_db = await self.get_permission(permission_id)
if permission_db:
await self._db_session.refresh(permission_db) # force refresh of updated_at value
return permission_db
async def delete_permission(self, permission_id: UUID) -> bool:
"""
Delete a permission.
"""
query = delete(models.Permission).where(models.Permission.permission_id == permission_id)
result = await self._db_session.execute(query)
await self._db_session.commit()
return result.rowcount > 0
def _match_permission(
self,
permissions: List[models.Permission],
method: str,
path: str
) -> Union[None, models.Permission]:
"""
Match the methods and path with a permission.
"""
for permission in permissions:
log.debug(f"RBAC: checking permission {permission.methods} {permission.path} {permission.action}")
if method not in permission.methods:
continue
if permission.path.endswith("/*") and path.startswith(permission.path[:-2]):
return permission
elif permission.path == path:
return permission
async def get_user_permissions(self, user_id: UUID):
"""
Get all permissions from an user.
"""
query = select(models.Permission).\
join(models.User.permissions). \
filter(models.User.user_id == user_id).\
order_by(models.Permission.path)
result = await self._db_session.execute(query)
return result.scalars().all()
async def add_permission_to_user(
self,
user_id: UUID,
permission: models.Permission
) -> Union[None, models.User]:
"""
Add a permission to an user.
"""
query = select(models.User).\
options(selectinload(models.User.permissions)).\
where(models.User.user_id == user_id)
result = await self._db_session.execute(query)
user_db = result.scalars().first()
if not user_db:
return None
user_db.permissions.append(permission)
await self._db_session.commit()
await self._db_session.refresh(user_db)
return user_db
async def remove_permission_from_user(
self,
user_id: UUID,
permission: models.Permission
) -> Union[None, models.User]:
"""
Remove a permission from a role.
"""
query = select(models.User).\
options(selectinload(models.User.permissions)).\
where(models.User.user_id == user_id)
result = await self._db_session.execute(query)
user_db = result.scalars().first()
if not user_db:
return None
user_db.permissions.remove(permission)
await self._db_session.commit()
await self._db_session.refresh(user_db)
return user_db
async def add_permission_to_user_with_path(self, user_id: UUID, path: str) -> Union[None, models.User]:
"""
Add a permission to an user.
"""
# Create a new permission with full rights on path
new_permission = schemas.PermissionCreate(
description=f"Allow access to {path}",
methods=[HTTPMethods.get, HTTPMethods.head, HTTPMethods.post, HTTPMethods.put, HTTPMethods.delete],
path=path,
action=PermissionAction.allow
)
permission_db = await self.create_permission(new_permission)
# Add the permission to the user
query = select(models.User).\
options(selectinload(models.User.permissions)).\
where(models.User.user_id == user_id)
result = await self._db_session.execute(query)
user_db = result.scalars().first()
if not user_db:
return None
user_db.permissions.append(permission_db)
await self._db_session.commit()
await self._db_session.refresh(user_db)
return user_db
async def delete_all_permissions_with_path(self, path: str) -> None:
"""
Delete all permissions with path.
"""
query = delete(models.Permission).\
where(models.Permission.path.startswith(path)).\
execution_options(synchronize_session=False)
result = await self._db_session.execute(query)
log.debug(f"{result.rowcount} permission(s) have been deleted")
async def check_user_is_authorized(self, user_id: UUID, method: str, path: str) -> bool:
"""
Check if an user is authorized to access a resource.
"""
query = select(models.Permission).\
join(models.Permission.roles). \
join(models.Role.groups). \
join(models.UserGroup.users). \
filter(models.User.user_id == user_id).\
order_by(models.Permission.path)
result = await self._db_session.execute(query)
permissions = result.scalars().all()
log.debug(f"RBAC: checking authorization for user '{user_id}' on {method} '{path}'")
matched_permission = self._match_permission(permissions, method, path)
if matched_permission:
log.debug(f"RBAC: matched role permission {matched_permission.methods} "
f"{matched_permission.path} {matched_permission.action}")
if matched_permission.action == "DENY":
return False
return True
log.debug(f"RBAC: could not find a role permission, checking user permissions...")
permissions = await self.get_user_permissions(user_id)
matched_permission = self._match_permission(permissions, method, path)
if matched_permission:
log.debug(f"RBAC: matched user permission {matched_permission.methods} "
f"{matched_permission.path} {matched_permission.action}")
if matched_permission.action == "DENY":
return False
return True
return False

View File

@ -33,40 +33,59 @@ log = logging.getLogger(__name__)
class UsersRepository(BaseRepository): class UsersRepository(BaseRepository):
def __init__(self, db_session: AsyncSession) -> None: def __init__(self, db_session: AsyncSession) -> None:
super().__init__(db_session) super().__init__(db_session)
self._auth_service = auth_service self._auth_service = auth_service
async def get_user(self, user_id: UUID) -> Optional[models.User]: async def get_user(self, user_id: UUID) -> Optional[models.User]:
"""
Get an user by its ID.
"""
query = select(models.User).where(models.User.user_id == user_id) query = select(models.User).where(models.User.user_id == user_id)
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
return result.scalars().first() return result.scalars().first()
async def get_user_by_username(self, username: str) -> Optional[models.User]: async def get_user_by_username(self, username: str) -> Optional[models.User]:
"""
Get an user by its name.
"""
query = select(models.User).where(models.User.username == username) query = select(models.User).where(models.User.username == username)
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
return result.scalars().first() return result.scalars().first()
async def get_user_by_email(self, email: str) -> Optional[models.User]: async def get_user_by_email(self, email: str) -> Optional[models.User]:
"""
Get an user by its email.
"""
query = select(models.User).where(models.User.email == email) query = select(models.User).where(models.User.email == email)
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
return result.scalars().first() return result.scalars().first()
async def get_users(self) -> List[models.User]: async def get_users(self) -> List[models.User]:
"""
Get all users.
"""
query = select(models.User) query = select(models.User)
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
return result.scalars().all() return result.scalars().all()
async def create_user(self, user: schemas.UserCreate) -> models.User: async def create_user(self, user: schemas.UserCreate) -> models.User:
"""
Create a new user.
"""
hashed_password = self._auth_service.hash_password(user.password.get_secret_value()) hashed_password = self._auth_service.hash_password(user.password.get_secret_value())
db_user = models.User( db_user = models.User(
username=user.username, email=user.email, full_name=user.full_name, hashed_password=hashed_password username=user.username,
email=user.email,
full_name=user.full_name,
hashed_password=hashed_password
) )
self._db_session.add(db_user) self._db_session.add(db_user)
await self._db_session.commit() await self._db_session.commit()
@ -74,6 +93,9 @@ class UsersRepository(BaseRepository):
return db_user return db_user
async def update_user(self, user_id: UUID, user_update: schemas.UserUpdate) -> Optional[models.User]: async def update_user(self, user_id: UUID, user_update: schemas.UserUpdate) -> Optional[models.User]:
"""
Update an user.
"""
update_values = user_update.dict(exclude_unset=True) update_values = user_update.dict(exclude_unset=True)
password = update_values.pop("password", None) password = update_values.pop("password", None)
@ -92,6 +114,9 @@ class UsersRepository(BaseRepository):
return user_db return user_db
async def delete_user(self, user_id: UUID) -> bool: async def delete_user(self, user_id: UUID) -> bool:
"""
Delete an user.
"""
query = delete(models.User).where(models.User.user_id == user_id) query = delete(models.User).where(models.User.user_id == user_id)
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
@ -99,6 +124,9 @@ class UsersRepository(BaseRepository):
return result.rowcount > 0 return result.rowcount > 0
async def authenticate_user(self, username: str, password: str) -> Optional[models.User]: async def authenticate_user(self, username: str, password: str) -> Optional[models.User]:
"""
Authenticate an user.
"""
user = await self.get_user_by_username(username) user = await self.get_user_by_username(username)
if not user: if not user:
@ -115,6 +143,9 @@ class UsersRepository(BaseRepository):
return user return user
async def get_user_memberships(self, user_id: UUID) -> List[models.UserGroup]: async def get_user_memberships(self, user_id: UUID) -> List[models.UserGroup]:
"""
Get all user memberships (user groups).
"""
query = select(models.UserGroup).\ query = select(models.UserGroup).\
join(models.UserGroup.users).\ join(models.UserGroup.users).\
@ -124,24 +155,36 @@ class UsersRepository(BaseRepository):
return result.scalars().all() return result.scalars().all()
async def get_user_group(self, user_group_id: UUID) -> Optional[models.UserGroup]: async def get_user_group(self, user_group_id: UUID) -> Optional[models.UserGroup]:
"""
Get an user group by its ID.
"""
query = select(models.UserGroup).where(models.UserGroup.user_group_id == user_group_id) query = select(models.UserGroup).where(models.UserGroup.user_group_id == user_group_id)
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
return result.scalars().first() return result.scalars().first()
async def get_user_group_by_name(self, name: str) -> Optional[models.UserGroup]: async def get_user_group_by_name(self, name: str) -> Optional[models.UserGroup]:
"""
Get an user group by its name.
"""
query = select(models.UserGroup).where(models.UserGroup.name == name) query = select(models.UserGroup).where(models.UserGroup.name == name)
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
return result.scalars().first() return result.scalars().first()
async def get_user_groups(self) -> List[models.UserGroup]: async def get_user_groups(self) -> List[models.UserGroup]:
"""
Get all user groups.
"""
query = select(models.UserGroup) query = select(models.UserGroup)
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
return result.scalars().all() return result.scalars().all()
async def create_user_group(self, user_group: schemas.UserGroupCreate) -> models.UserGroup: async def create_user_group(self, user_group: schemas.UserGroupCreate) -> models.UserGroup:
"""
Create a new user group.
"""
db_user_group = models.UserGroup(name=user_group.name) db_user_group = models.UserGroup(name=user_group.name)
self._db_session.add(db_user_group) self._db_session.add(db_user_group)
@ -154,6 +197,9 @@ class UsersRepository(BaseRepository):
user_group_id: UUID, user_group_id: UUID,
user_group_update: schemas.UserGroupUpdate user_group_update: schemas.UserGroupUpdate
) -> Optional[models.UserGroup]: ) -> Optional[models.UserGroup]:
"""
Update an user group.
"""
update_values = user_group_update.dict(exclude_unset=True) update_values = user_group_update.dict(exclude_unset=True)
query = update(models.UserGroup).\ query = update(models.UserGroup).\
@ -168,6 +214,9 @@ class UsersRepository(BaseRepository):
return user_group_db return user_group_db
async def delete_user_group(self, user_group_id: UUID) -> bool: async def delete_user_group(self, user_group_id: UUID) -> bool:
"""
Delete an user group.
"""
query = delete(models.UserGroup).where(models.UserGroup.user_group_id == user_group_id) query = delete(models.UserGroup).where(models.UserGroup.user_group_id == user_group_id)
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
@ -179,6 +228,9 @@ class UsersRepository(BaseRepository):
user_group_id: UUID, user_group_id: UUID,
user: models.User user: models.User
) -> Union[None, models.UserGroup]: ) -> Union[None, models.UserGroup]:
"""
Add a member to an user group.
"""
query = select(models.UserGroup).\ query = select(models.UserGroup).\
options(selectinload(models.UserGroup.users)).\ options(selectinload(models.UserGroup.users)).\
@ -198,6 +250,9 @@ class UsersRepository(BaseRepository):
user_group_id: UUID, user_group_id: UUID,
user: models.User user: models.User
) -> Union[None, models.UserGroup]: ) -> Union[None, models.UserGroup]:
"""
Remove a member from an user group.
"""
query = select(models.UserGroup).\ query = select(models.UserGroup).\
options(selectinload(models.UserGroup.users)).\ options(selectinload(models.UserGroup.users)).\
@ -213,6 +268,9 @@ class UsersRepository(BaseRepository):
return user_group_db return user_group_db
async def get_user_group_members(self, user_group_id: UUID) -> List[models.User]: async def get_user_group_members(self, user_group_id: UUID) -> List[models.User]:
"""
Get all members from an user group.
"""
query = select(models.User).\ query = select(models.User).\
join(models.User.groups).\ join(models.User.groups).\
@ -220,3 +278,60 @@ class UsersRepository(BaseRepository):
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
return result.scalars().all() return result.scalars().all()
async def add_role_to_user_group(
self,
user_group_id: UUID,
role: models.Role
) -> Union[None, models.UserGroup]:
"""
Add a role to an user group.
"""
query = select(models.UserGroup).\
options(selectinload(models.UserGroup.roles)).\
where(models.UserGroup.user_group_id == user_group_id)
result = await self._db_session.execute(query)
user_group_db = result.scalars().first()
if not user_group_db:
return None
user_group_db.roles.append(role)
await self._db_session.commit()
await self._db_session.refresh(user_group_db)
return user_group_db
async def remove_role_from_user_group(
self,
user_group_id: UUID,
role: models.Role
) -> Union[None, models.UserGroup]:
"""
Remove a role from an user group.
"""
query = select(models.UserGroup).\
options(selectinload(models.UserGroup.roles)).\
where(models.UserGroup.user_group_id == user_group_id)
result = await self._db_session.execute(query)
user_group_db = result.scalars().first()
if not user_group_db:
return None
user_group_db.roles.remove(role)
await self._db_session.commit()
await self._db_session.refresh(user_group_db)
return user_group_db
async def get_user_group_roles(self, user_group_id: UUID) -> List[models.Role]:
"""
Get all roles from an user group.
"""
query = select(models.Role). \
options(selectinload(models.Role.permissions)). \
join(models.UserGroup.roles). \
filter(models.UserGroup.user_group_id == user_group_id)
result = await self._db_session.execute(query)
return result.scalars().all()

View File

@ -28,6 +28,7 @@ from .controller.gns3vm import GNS3VM
from .controller.nodes import NodeCreate, NodeUpdate, NodeDuplicate, NodeCapture, Node from .controller.nodes import NodeCreate, NodeUpdate, NodeDuplicate, NodeCapture, Node
from .controller.projects import ProjectCreate, ProjectUpdate, ProjectDuplicate, Project, ProjectFile from .controller.projects import ProjectCreate, ProjectUpdate, ProjectDuplicate, Project, ProjectFile
from .controller.users import UserCreate, UserUpdate, User, Credentials, UserGroupCreate, UserGroupUpdate, UserGroup from .controller.users import UserCreate, UserUpdate, User, Credentials, UserGroupCreate, UserGroupUpdate, UserGroup
from .controller.rbac import RoleCreate, RoleUpdate, Role, PermissionCreate, PermissionUpdate, Permission
from .controller.tokens import Token from .controller.tokens import Token
from .controller.snapshots import SnapshotCreate, Snapshot from .controller.snapshots import SnapshotCreate, Snapshot
from .controller.iou_license import IOULicense from .controller.iou_license import IOULicense

View File

@ -0,0 +1,121 @@
#
# Copyright (C) 2020 GNS3 Technologies Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from typing import Optional, List
from pydantic import BaseModel, validator
from uuid import UUID
from enum import Enum
from .base import DateTimeModelMixin
class HTTPMethods(str, Enum):
"""
HTTP method type.
"""
get = "GET"
head = "HEAD"
post = "POST"
patch = "PATCH"
put = "PUT"
delete = "DELETE"
class PermissionAction(str, Enum):
"""
Action to perform when permission is matched.
"""
allow = "ALLOW"
deny = "DENY"
class PermissionBase(BaseModel):
"""
Common permission properties.
"""
methods: List[HTTPMethods]
path: str
action: PermissionAction
description: Optional[str] = None
class Config:
use_enum_values = True
@validator("action", pre=True)
def action_uppercase(cls, v):
return v.upper()
class PermissionCreate(PermissionBase):
"""
Properties to create a permission.
"""
pass
class PermissionUpdate(PermissionBase):
"""
Properties to update a role.
"""
pass
class Permission(DateTimeModelMixin, PermissionBase):
permission_id: UUID
class Config:
orm_mode = True
class RoleBase(BaseModel):
"""
Common role properties.
"""
name: Optional[str] = None
description: Optional[str] = None
class RoleCreate(RoleBase):
"""
Properties to create a role.
"""
name: str
class RoleUpdate(RoleBase):
"""
Properties to update a role.
"""
pass
class Role(DateTimeModelMixin, RoleBase):
role_id: UUID
builtin: bool
permissions: List[Permission]
class Config:
orm_mode = True

View File

@ -85,7 +85,7 @@ class UserGroupUpdate(UserGroupBase):
class UserGroup(DateTimeModelMixin, UserGroupBase): class UserGroup(DateTimeModelMixin, UserGroupBase):
user_group_id: UUID user_group_id: UUID
is_updatable: bool builtin: bool
class Config: class Config:
orm_mode = True orm_mode = True

View File

@ -22,7 +22,10 @@ from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from gns3server.db.repositories.users import UsersRepository from gns3server.db.repositories.users import UsersRepository
from gns3server.db.repositories.rbac import RbacRepository
from gns3server.schemas.controller.users import User from gns3server.schemas.controller.users import User
from gns3server.schemas.controller.rbac import Role
from gns3server import schemas
pytestmark = pytest.mark.asyncio pytestmark = pytest.mark.asyncio
@ -47,7 +50,7 @@ class TestGroupRoutes:
response = await client.get(app.url_path_for("get_user_groups")) response = await client.get(app.url_path_for("get_user_groups"))
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == 4 # 3 default groups + group1 assert len(response.json()) == 3 # 2 default groups + group1
async def test_update_group(self, app: FastAPI, client: AsyncClient, db_session: AsyncSession) -> None: async def test_update_group(self, app: FastAPI, client: AsyncClient, db_session: AsyncSession) -> None:
@ -103,6 +106,9 @@ class TestGroupRoutes:
response = await client.delete(app.url_path_for("delete_user_group", user_group_id=group_in_db.user_group_id)) response = await client.delete(app.url_path_for("delete_user_group", user_group_id=group_in_db.user_group_id))
assert response.status_code == status.HTTP_403_FORBIDDEN assert response.status_code == status.HTTP_403_FORBIDDEN
class TestGroupMembersRoutes:
async def test_add_member_to_group( async def test_add_member_to_group(
self, self,
app: FastAPI, app: FastAPI,
@ -163,3 +169,84 @@ class TestGroupRoutes:
assert response.status_code == status.HTTP_204_NO_CONTENT assert response.status_code == status.HTTP_204_NO_CONTENT
members = await user_repo.get_user_group_members(group_in_db.user_group_id) members = await user_repo.get_user_group_members(group_in_db.user_group_id)
assert len(members) == 0 assert len(members) == 0
@pytest.fixture
async def test_role(db_session: AsyncSession) -> Role:
new_role = schemas.RoleCreate(
name="TestRole",
description="This is my test role"
)
rbac_repo = RbacRepository(db_session)
existing_role = await rbac_repo.get_role_by_name(new_role.name)
if existing_role:
return existing_role
return await rbac_repo.create_role(new_role)
class TestGroupRolesRoutes:
async def test_add_role_to_group(
self,
app: FastAPI,
client: AsyncClient,
test_role: Role,
db_session: AsyncSession
) -> None:
user_repo = UsersRepository(db_session)
group_in_db = await user_repo.get_user_group_by_name("Users")
response = await client.put(
app.url_path_for(
"add_role_to_group",
user_group_id=group_in_db.user_group_id,
role_id=str(test_role.role_id)
)
)
assert response.status_code == status.HTTP_204_NO_CONTENT
roles = await user_repo.get_user_group_roles(group_in_db.user_group_id)
assert len(roles) == 2 # 1 default role + 1 custom role
for role in roles:
if not role.builtin:
assert role.name == test_role.name
async def test_get_user_group_roles(
self,
app: FastAPI,
client: AsyncClient,
db_session: AsyncSession
) -> None:
user_repo = UsersRepository(db_session)
group_in_db = await user_repo.get_user_group_by_name("Users")
response = await client.get(
app.url_path_for(
"get_user_group_roles",
user_group_id=group_in_db.user_group_id)
)
assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == 2 # 1 default role + 1 custom role
async def test_remove_role_from_group(
self,
app: FastAPI,
client: AsyncClient,
test_role: Role,
db_session: AsyncSession
) -> None:
user_repo = UsersRepository(db_session)
group_in_db = await user_repo.get_user_group_by_name("Users")
response = await client.delete(
app.url_path_for(
"remove_role_from_group",
user_group_id=group_in_db.user_group_id,
role_id=test_role.role_id
),
)
assert response.status_code == status.HTTP_204_NO_CONTENT
roles = await user_repo.get_user_group_roles(group_in_db.user_group_id)
assert len(roles) == 1 # 1 default role
assert roles[0].name != test_role.name

View File

@ -0,0 +1,83 @@
#!/usr/bin/env python
#
# Copyright (C) 2021 GNS3 Technologies Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import pytest
from fastapi import FastAPI, status
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from gns3server.db.repositories.rbac import RbacRepository
pytestmark = pytest.mark.asyncio
class TestPermissionRoutes:
async def test_create_permission(self, app: FastAPI, client: AsyncClient) -> None:
new_permission = {
"methods": ["GET"],
"path": "/templates",
"action": "ALLOW"
}
response = await client.post(app.url_path_for("create_permission"), json=new_permission)
assert response.status_code == status.HTTP_201_CREATED
async def test_get_permission(self, app: FastAPI, client: AsyncClient, db_session: AsyncSession) -> None:
rbac_repo = RbacRepository(db_session)
permission_in_db = await rbac_repo.get_permission_by_path("/templates")
response = await client.get(app.url_path_for("get_permission", permission_id=permission_in_db.permission_id))
assert response.status_code == status.HTTP_200_OK
assert response.json()["permission_id"] == str(permission_in_db.permission_id)
async def test_list_permissions(self, app: FastAPI, client: AsyncClient) -> None:
response = await client.get(app.url_path_for("get_permissions"))
assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == 6 # 5 default permissions + 1 custom permission
async def test_update_permission(self, app: FastAPI, client: AsyncClient, db_session: AsyncSession) -> None:
rbac_repo = RbacRepository(db_session)
permission_in_db = await rbac_repo.get_permission_by_path("/templates")
update_permission = {
"methods": ["GET"],
"path": "/appliances",
"action": "ALLOW"
}
response = await client.put(
app.url_path_for("update_permission", permission_id=permission_in_db.permission_id),
json=update_permission
)
assert response.status_code == status.HTTP_200_OK
updated_permission_in_db = await rbac_repo.get_permission(permission_in_db.permission_id)
assert updated_permission_in_db.path == "/appliances"
async def test_delete_permission(
self,
app: FastAPI,
client: AsyncClient,
db_session: AsyncSession
) -> None:
rbac_repo = RbacRepository(db_session)
permission_in_db = await rbac_repo.get_permission_by_path("/appliances")
response = await client.delete(app.url_path_for("delete_permission", permission_id=permission_in_db.permission_id))
assert response.status_code == status.HTTP_204_NO_CONTENT

View File

@ -0,0 +1,185 @@
#!/usr/bin/env python
#
# Copyright (C) 2021 GNS3 Technologies Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import pytest
from fastapi import FastAPI, status
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from gns3server.db.repositories.rbac import RbacRepository
from gns3server.schemas.controller.rbac import Permission, HTTPMethods, PermissionAction
from gns3server import schemas
pytestmark = pytest.mark.asyncio
class TestRolesRoutes:
async def test_create_role(self, app: FastAPI, client: AsyncClient) -> None:
new_role = {"name": "role1"}
response = await client.post(app.url_path_for("create_role"), json=new_role)
assert response.status_code == status.HTTP_201_CREATED
async def test_get_role(self, app: FastAPI, client: AsyncClient, db_session: AsyncSession) -> None:
rbac_repo = RbacRepository(db_session)
role_in_db = await rbac_repo.get_role_by_name("role1")
response = await client.get(app.url_path_for("get_role", role_id=role_in_db.role_id))
assert response.status_code == status.HTTP_200_OK
assert response.json()["role_id"] == str(role_in_db.role_id)
async def test_list_roles(self, app: FastAPI, client: AsyncClient) -> None:
response = await client.get(app.url_path_for("get_roles"))
assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == 3 # 2 default roles + role1
async def test_update_role(self, app: FastAPI, client: AsyncClient, db_session: AsyncSession) -> None:
rbac_repo = RbacRepository(db_session)
role_in_db = await rbac_repo.get_role_by_name("role1")
update_role = {"name": "role42"}
response = await client.put(
app.url_path_for("update_role", role_id=role_in_db.role_id),
json=update_role
)
assert response.status_code == status.HTTP_200_OK
updated_role_in_db = await rbac_repo.get_role(role_in_db.role_id)
assert updated_role_in_db.name == "role42"
async def test_cannot_update_builtin_user_role(
self,
app: FastAPI,
client: AsyncClient,
db_session: AsyncSession
) -> None:
rbac_repo = RbacRepository(db_session)
role_in_db = await rbac_repo.get_role_by_name("User")
update_role = {"name": "Hackers"}
response = await client.put(
app.url_path_for("update_role", role_id=role_in_db.role_id),
json=update_role
)
assert response.status_code == status.HTTP_403_FORBIDDEN
async def test_delete_role(
self,
app: FastAPI,
client: AsyncClient,
db_session: AsyncSession
) -> None:
rbac_repo = RbacRepository(db_session)
role_in_db = await rbac_repo.get_role_by_name("role42")
response = await client.delete(app.url_path_for("delete_role", role_id=role_in_db.role_id))
assert response.status_code == status.HTTP_204_NO_CONTENT
async def test_cannot_delete_builtin_administrator_role(
self,
app: FastAPI,
client: AsyncClient,
db_session: AsyncSession
) -> None:
rbac_repo = RbacRepository(db_session)
role_in_db = await rbac_repo.get_role_by_name("Administrator")
response = await client.delete(app.url_path_for("delete_role", role_id=role_in_db.role_id))
assert response.status_code == status.HTTP_403_FORBIDDEN
@pytest.fixture
async def test_permission(db_session: AsyncSession) -> Permission:
new_permission = schemas.PermissionCreate(
methods=[HTTPMethods.get],
path="/statistics",
action=PermissionAction.allow
)
rbac_repo = RbacRepository(db_session)
existing_permission = await rbac_repo.get_permission_by_path("/statistics")
if existing_permission:
return existing_permission
return await rbac_repo.create_permission(new_permission)
class TestRolesPermissionsRoutes:
async def test_add_permission_to_role(
self,
app: FastAPI,
client: AsyncClient,
test_permission: Permission,
db_session: AsyncSession
) -> None:
rbac_repo = RbacRepository(db_session)
role_in_db = await rbac_repo.get_role_by_name("User")
response = await client.put(
app.url_path_for(
"add_permission_to_role",
role_id=role_in_db.role_id,
permission_id=str(test_permission.permission_id)
)
)
assert response.status_code == status.HTTP_204_NO_CONTENT
permissions = await rbac_repo.get_role_permissions(role_in_db.role_id)
assert len(permissions) == 5 # 4 default permissions + 1 custom permission
async def test_get_role_permissions(
self,
app: FastAPI,
client: AsyncClient,
db_session: AsyncSession
) -> None:
rbac_repo = RbacRepository(db_session)
role_in_db = await rbac_repo.get_role_by_name("User")
response = await client.get(
app.url_path_for(
"get_role_permissions",
role_id=role_in_db.role_id)
)
assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == 5 # 4 default permissions + 1 custom permission
async def test_remove_role_from_group(
self,
app: FastAPI,
client: AsyncClient,
test_permission: Permission,
db_session: AsyncSession
) -> None:
rbac_repo = RbacRepository(db_session)
role_in_db = await rbac_repo.get_role_by_name("User")
response = await client.delete(
app.url_path_for(
"remove_permission_from_role",
role_id=role_in_db.role_id,
permission_id=str(test_permission.permission_id)
),
)
assert response.status_code == status.HTTP_204_NO_CONTENT
permissions = await rbac_repo.get_role_permissions(role_in_db.role_id)
assert len(permissions) == 4 # 4 default permissions

View File

@ -25,9 +25,12 @@ from jose import jwt
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from gns3server.db.repositories.users import UsersRepository from gns3server.db.repositories.users import UsersRepository
from gns3server.db.repositories.rbac import RbacRepository
from gns3server.schemas.controller.rbac import Permission, HTTPMethods, PermissionAction
from gns3server.services import auth_service from gns3server.services import auth_service
from gns3server.config import Config from gns3server.config import Config
from gns3server.schemas.controller.users import User from gns3server.schemas.controller.users import User
from gns3server import schemas
import gns3server.db.models as models import gns3server.db.models as models
pytestmark = pytest.mark.asyncio pytestmark = pytest.mark.asyncio
@ -266,7 +269,7 @@ class TestUserMe:
test_user: User, test_user: User,
) -> None: ) -> None:
response = await authorized_client.get(app.url_path_for("get_current_active_user")) response = await authorized_client.get(app.url_path_for("get_logged_in_user"))
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
user = User(**response.json()) user = User(**response.json())
assert user.username == test_user.username assert user.username == test_user.username
@ -279,7 +282,7 @@ class TestUserMe:
test_user: User, test_user: User,
) -> None: ) -> None:
response = await unauthorized_client.get(app.url_path_for("get_current_active_user")) response = await unauthorized_client.get(app.url_path_for("get_logged_in_user"))
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
@ -329,15 +332,91 @@ class TestSuperAdmin:
response = await unauthorized_client.post(app.url_path_for("login"), data=login_data) response = await unauthorized_client.post(app.url_path_for("login"), data=login_data)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
async def test_super_admin_belongs_to_admin_group( # async def test_super_admin_belongs_to_admin_group(
# self,
# app: FastAPI,
# client: AsyncClient,
# db_session: AsyncSession
# ) -> None:
#
# user_repo = UsersRepository(db_session)
# admin_in_db = await user_repo.get_user_by_username("admin")
# response = await client.get(app.url_path_for("get_user_memberships", user_id=admin_in_db.user_id))
# assert response.status_code == status.HTTP_200_OK
# assert len(response.json()) == 1
@pytest.fixture
async def test_permission(db_session: AsyncSession) -> Permission:
new_permission = schemas.PermissionCreate(
methods=[HTTPMethods.get],
path="/statistics",
action=PermissionAction.allow
)
rbac_repo = RbacRepository(db_session)
existing_permission = await rbac_repo.get_permission_by_path("/statistics")
if existing_permission:
return existing_permission
return await rbac_repo.create_permission(new_permission)
class TestUserPermissionsRoutes:
async def test_add_permission_to_user(
self, self,
app: FastAPI, app: FastAPI,
client: AsyncClient, client: AsyncClient,
test_user: User,
test_permission: Permission,
db_session: AsyncSession db_session: AsyncSession
) -> None: ) -> None:
user_repo = UsersRepository(db_session) response = await client.put(
admin_in_db = await user_repo.get_user_by_username("admin") app.url_path_for(
response = await client.get(app.url_path_for("get_user_memberships", user_id=admin_in_db.user_id)) "add_permission_to_user",
user_id=str(test_user.user_id),
permission_id=str(test_permission.permission_id)
)
)
assert response.status_code == status.HTTP_204_NO_CONTENT
rbac_repo = RbacRepository(db_session)
permissions = await rbac_repo.get_user_permissions(test_user.user_id)
assert len(permissions) == 1
assert permissions[0].permission_id == test_permission.permission_id
async def test_get_user_permissions(
self,
app: FastAPI,
client: AsyncClient,
test_user: User,
db_session: AsyncSession
) -> None:
response = await client.get(
app.url_path_for(
"get_user_permissions",
user_id=str(test_user.user_id))
)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == 1 assert len(response.json()) == 1
async def test_remove_permission_from_user(
self,
app: FastAPI,
client: AsyncClient,
test_user: User,
test_permission: Permission,
db_session: AsyncSession
) -> None:
response = await client.delete(
app.url_path_for(
"remove_permission_from_user",
user_id=str(test_user.user_id),
permission_id=str(test_permission.permission_id)
),
)
assert response.status_code == status.HTTP_204_NO_CONTENT
rbac_repo = RbacRepository(db_session)
permissions = await rbac_repo.get_user_permissions(test_user.user_id)
assert len(permissions) == 0

View File

@ -124,7 +124,12 @@ async def test_user(db_session: AsyncSession) -> User:
existing_user = await user_repo.get_user_by_username(new_user.username) existing_user = await user_repo.get_user_by_username(new_user.username)
if existing_user: if existing_user:
return existing_user return existing_user
return await user_repo.create_user(new_user) user = await user_repo.create_user(new_user)
# add new user to "Users group
group = await user_repo.get_user_group_by_name("Users")
await user_repo.add_member_to_user_group(group.user_group_id, user)
return user
@pytest.fixture @pytest.fixture

View File

@ -0,0 +1,203 @@
#!/usr/bin/env python
#
# Copyright (C) 2021 GNS3 Technologies Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import pytest
from fastapi import FastAPI, status
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from gns3server.db.repositories.rbac import RbacRepository
from gns3server.db.models import User
pytestmark = pytest.mark.asyncio
class TestPermissions:
@pytest.mark.parametrize(
"method, path, result",
(
("GET", "/users", False),
("GET", "/projects", True),
("GET", "/projects/e451ad73-2519-4f83-87fe-a8e821792d44", False),
("POST", "/projects", True),
("GET", "/templates", True),
("GET", "/templates/62e92cf1-244a-4486-8dae-b95439b54da9", False),
("POST", "/templates", True),
("GET", "/computes", True),
("GET", "/computes/local", True),
("GET", "/symbols", True),
("GET", "/symbols/default_symbols", True),
),
)
async def test_default_permissions_user_group(
self,
app: FastAPI,
authorized_client: AsyncClient,
test_user: User,
db_session: AsyncSession,
method: str,
path: str,
result: bool
) -> None:
rbac_repo = RbacRepository(db_session)
authorized = await rbac_repo.check_user_is_authorized(test_user.user_id, method, path)
assert authorized == result
class TestProjectsWithRbac:
async def test_admin_create_project(self, app: FastAPI, client: AsyncClient):
params = {"name": "Admin project"}
response = await client.post(app.url_path_for("create_project"), json=params)
assert response.status_code == status.HTTP_201_CREATED
async def test_user_only_access_own_projects(
self,
app: FastAPI,
authorized_client: AsyncClient,
test_user: User,
db_session: AsyncSession
) -> None:
params = {"name": "User project"}
response = await authorized_client.post(app.url_path_for("create_project"), json=params)
assert response.status_code == status.HTTP_201_CREATED
project_id = response.json()["project_id"]
rbac_repo = RbacRepository(db_session)
permissions_in_db = await rbac_repo.get_user_permissions(test_user.user_id)
assert len(permissions_in_db) == 1
assert permissions_in_db[0].path == f"/projects/{project_id}/*"
response = await authorized_client.get(app.url_path_for("get_projects"))
assert response.status_code == status.HTTP_200_OK
projects = response.json()
assert len(projects) == 1
async def test_admin_access_all_projects(self, app: FastAPI, client: AsyncClient):
response = await client.get(app.url_path_for("get_projects"))
assert response.status_code == status.HTTP_200_OK
projects = response.json()
assert len(projects) == 2
async def test_admin_user_give_permission_on_project(
self,
app: FastAPI,
client: AsyncClient,
test_user: User
):
response = await client.get(app.url_path_for("get_projects"))
assert response.status_code == status.HTTP_200_OK
projects = response.json()
project_id = None
for project in projects:
if project["name"] == "Admin project":
project_id = project["project_id"]
break
new_permission = {
"methods": ["GET"],
"path": f"/projects/{project_id}",
"action": "ALLOW"
}
response = await client.post(app.url_path_for("create_permission"), json=new_permission)
assert response.status_code == status.HTTP_201_CREATED
permission_id = response.json()["permission_id"]
response = await client.put(
app.url_path_for(
"add_permission_to_user",
user_id=test_user.user_id,
permission_id=permission_id
)
)
assert response.status_code == status.HTTP_204_NO_CONTENT
async def test_user_access_admin_project(
self,
app: FastAPI,
authorized_client: AsyncClient,
test_user: User,
db_session: AsyncSession
) -> None:
response = await authorized_client.get(app.url_path_for("get_projects"))
assert response.status_code == status.HTTP_200_OK
projects = response.json()
assert len(projects) == 2
class TestTemplatesWithRbac:
async def test_admin_create_template(self, app: FastAPI, client: AsyncClient):
new_template = {"base_script_file": "vpcs_base_config.txt",
"category": "guest",
"console_auto_start": False,
"console_type": "telnet",
"default_name_format": "PC{0}",
"name": "ADMIN_VPCS_TEMPLATE",
"compute_id": "local",
"symbol": ":/symbols/vpcs_guest.svg",
"template_type": "vpcs"}
response = await client.post(app.url_path_for("create_template"), json=new_template)
assert response.status_code == status.HTTP_201_CREATED
async def test_user_only_access_own_templates(
self, app: FastAPI,
authorized_client: AsyncClient,
test_user: User,
db_session: AsyncSession
) -> None:
new_template = {"base_script_file": "vpcs_base_config.txt",
"category": "guest",
"console_auto_start": False,
"console_type": "telnet",
"default_name_format": "PC{0}",
"name": "USER_VPCS_TEMPLATE",
"compute_id": "local",
"symbol": ":/symbols/vpcs_guest.svg",
"template_type": "vpcs"}
response = await authorized_client.post(app.url_path_for("create_template"), json=new_template)
assert response.status_code == status.HTTP_201_CREATED
template_id = response.json()["template_id"]
rbac_repo = RbacRepository(db_session)
permissions_in_db = await rbac_repo.get_user_permissions(test_user.user_id)
assert len(permissions_in_db) == 1
assert permissions_in_db[0].path == f"/templates/{template_id}/*"
response = await authorized_client.get(app.url_path_for("get_templates"))
assert response.status_code == status.HTTP_200_OK
templates = [template for template in response.json() if template["builtin"] is False]
assert len(templates) == 1
async def test_admin_access_all_templates(self, app: FastAPI, client: AsyncClient):
response = await client.get(app.url_path_for("get_templates"))
assert response.status_code == status.HTTP_200_OK
templates = [template for template in response.json() if template["builtin"] is False]
assert len(templates) == 2