Protect the API and add alternative authentication endpoint.

This commit is contained in:
grossmj
2021-04-20 11:59:02 +09:30
parent e28452f09a
commit 0465cb87f6
7 changed files with 199 additions and 80 deletions

View File

@ -14,7 +14,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from fastapi import APIRouter
from fastapi import APIRouter, Depends
from . import controller
from . import appliances
@ -30,17 +30,80 @@ from . import symbols
from . import templates
from . import users
from .dependencies.authentication import get_current_active_user
router = APIRouter()
router.include_router(controller.router, tags=["Controller"])
router.include_router(users.router, prefix="/users", tags=["Users"])
router.include_router(appliances.router, prefix="/appliances", tags=["Appliances"])
router.include_router(computes.router, prefix="/computes", tags=["Computes"])
router.include_router(drawings.router, prefix="/projects/{project_id}/drawings", tags=["Drawings"])
router.include_router(gns3vm.router, prefix="/gns3vm", tags=["GNS3 VM"])
router.include_router(links.router, prefix="/projects/{project_id}/links", tags=["Links"])
router.include_router(nodes.router, prefix="/projects/{project_id}/nodes", tags=["Nodes"])
router.include_router(notifications.router, prefix="/notifications", tags=["Notifications"])
router.include_router(projects.router, prefix="/projects", tags=["Projects"])
router.include_router(snapshots.router, prefix="/projects/{project_id}/snapshots", tags=["Snapshots"])
router.include_router(symbols.router, prefix="/symbols", tags=["Symbols"])
router.include_router(templates.router, tags=["Templates"])
router.include_router(
appliances.router,
dependencies=[Depends(get_current_active_user)],
prefix="/appliances",
tags=["Appliances"]
)
router.include_router(
computes.router,
dependencies=[Depends(get_current_active_user)],
prefix="/computes",
tags=["Computes"]
)
router.include_router(
drawings.router,
dependencies=[Depends(get_current_active_user)],
prefix="/projects/{project_id}/drawings",
tags=["Drawings"])
router.include_router(
gns3vm.router,
dependencies=[Depends(get_current_active_user)],
prefix="/gns3vm",
tags=["GNS3 VM"]
)
router.include_router(
links.router,
dependencies=[Depends(get_current_active_user)],
prefix="/projects/{project_id}/links",
tags=["Links"]
)
router.include_router(
nodes.router,
dependencies=[Depends(get_current_active_user)],
prefix="/projects/{project_id}/nodes",
tags=["Nodes"]
)
router.include_router(
notifications.router,
dependencies=[Depends(get_current_active_user)],
prefix="/notifications",
tags=["Notifications"])
router.include_router(
projects.router,
dependencies=[Depends(get_current_active_user)],
prefix="/projects",
tags=["Projects"])
router.include_router(
snapshots.router,
dependencies=[Depends(get_current_active_user)],
prefix="/projects/{project_id}/snapshots",
tags=["Snapshots"])
router.include_router(
symbols.router,
dependencies=[Depends(get_current_active_user)],
prefix="/symbols", tags=["Symbols"]
)
router.include_router(
templates.router,
dependencies=[Depends(get_current_active_user)],
tags=["Templates"]
)

View File

@ -18,7 +18,7 @@ import asyncio
import signal
import os
from fastapi import APIRouter, status
from fastapi import APIRouter, Depends, status
from fastapi.encoders import jsonable_encoder
from typing import List
@ -28,6 +28,7 @@ from gns3server.version import __version__
from gns3server.controller.controller_error import ControllerError, ControllerForbiddenError
from gns3server import schemas
from .dependencies.authentication import get_current_active_user
import logging
@ -36,8 +37,39 @@ log = logging.getLogger(__name__)
router = APIRouter()
@router.get(
"/version",
response_model=schemas.Version,
)
def get_version() -> dict:
"""
Return the server version number.
"""
local_server = Config.instance().settings.Server.local
return {"version": __version__, "local": local_server}
@router.post(
"/version",
response_model=schemas.Version,
response_model_exclude_defaults=True,
responses={409: {"model": schemas.ErrorMessage, "description": "Invalid version"}},
)
def check_version(version: schemas.Version) -> dict:
"""
Check if version is the same as the server.
"""
print(version.version)
if version.version != __version__:
raise ControllerError(f"Client version {version.version} is not the same as server version {__version__}")
return {"version": __version__}
@router.post(
"/shutdown",
dependencies=[Depends(get_current_active_user)],
status_code=status.HTTP_204_NO_CONTENT,
responses={403: {"model": schemas.ErrorMessage, "description": "Server shutdown not allowed"}},
)
@ -71,38 +103,11 @@ async def shutdown() -> None:
os.kill(os.getpid(), signal.SIGTERM)
@router.get("/version", response_model=schemas.Version)
def get_version() -> dict:
"""
Return the server version number.
"""
local_server = Config.instance().settings.Server.local
return {"version": __version__, "local": local_server}
@router.post(
"/version",
response_model=schemas.Version,
response_model_exclude_defaults=True,
responses={409: {"model": schemas.ErrorMessage, "description": "Invalid version"}},
@router.get(
"/iou_license",
dependencies=[Depends(get_current_active_user)],
response_model=schemas.IOULicense
)
def check_version(version: schemas.Version) -> dict:
"""
Check if version is the same as the server.
:param request:
:param response:
:return:
"""
print(version.version)
if version.version != __version__:
raise ControllerError(f"Client version {version.version} is not the same as server version {__version__}")
return {"version": __version__}
@router.get("/iou_license", response_model=schemas.IOULicense)
def get_iou_license() -> schemas.IOULicense:
"""
Return the IOU license settings
@ -111,7 +116,12 @@ def get_iou_license() -> schemas.IOULicense:
return Controller.instance().iou_license
@router.put("/iou_license", status_code=status.HTTP_201_CREATED, response_model=schemas.IOULicense)
@router.put(
"/iou_license",
dependencies=[Depends(get_current_active_user)],
status_code=status.HTTP_201_CREATED,
response_model=schemas.IOULicense
)
async def update_iou_license(iou_license: schemas.IOULicense) -> schemas.IOULicense:
"""
Update the IOU license settings.
@ -124,7 +134,7 @@ async def update_iou_license(iou_license: schemas.IOULicense) -> schemas.IOULice
return current_iou_license
@router.get("/statistics")
@router.get("/statistics", dependencies=[Depends(get_current_active_user)])
async def statistics() -> List[dict]:
"""
Return server statistics.

View File

@ -24,7 +24,7 @@ from gns3server.services import auth_service
from .database import get_repository
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/v3/users/login") # FIXME: URL prefix
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/v3/users/login")
async def get_user_from_token(

View File

@ -44,10 +44,53 @@ log = logging.getLogger(__name__)
router = APIRouter()
@router.get("", response_model=List[schemas.User])
@router.post("/login", response_model=schemas.Token)
async def login(
users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
form_data: OAuth2PasswordRequestForm = Depends(),
) -> schemas.Token:
"""
Default user login method using forms (x-www-form-urlencoded).
Example: curl http://host:port/v3/users/login -H "Content-Type: application/x-www-form-urlencoded" -d "username=admin&password=admin"
"""
user = await users_repo.authenticate_user(username=form_data.username, password=form_data.password)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication was unsuccessful.",
headers={"WWW-Authenticate": "Bearer"},
)
token = schemas.Token(access_token=auth_service.create_access_token(user.username), token_type="bearer")
return token
@router.post("/authenticate", response_model=schemas.Token)
async def authenticate(
user_credentials: schemas.Credentials,
users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
) -> schemas.Token:
"""
Alternative authentication method using json.
Example: curl http://host:port/v3/users/authenticate -d '{"username": "admin", "password": "admin"}'
"""
user = await users_repo.authenticate_user(username=user_credentials.username, password=user_credentials.password)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication was unsuccessful.",
headers={"WWW-Authenticate": "Bearer"},
)
token = schemas.Token(access_token=auth_service.create_access_token(user.username), token_type="bearer")
return token
@router.get("", response_model=List[schemas.User], dependencies=[Depends(get_current_active_user)])
async def get_users(
users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
current_user: schemas.User = Depends(get_current_active_user)
users_repo: UsersRepository = Depends(get_repository(UsersRepository))
) -> List[schemas.User]:
"""
Get all users.
@ -56,11 +99,15 @@ async def get_users(
return await users_repo.get_users()
@router.post("", response_model=schemas.User, status_code=status.HTTP_201_CREATED)
@router.post(
"",
response_model=schemas.User,
dependencies=[Depends(get_current_active_user)],
status_code=status.HTTP_201_CREATED
)
async def create_user(
user_create: schemas.UserCreate,
users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
current_user: schemas.User = Depends(get_current_active_user)
users_repo: UsersRepository = Depends(get_repository(UsersRepository))
) -> schemas.User:
"""
Create a new user.
@ -75,11 +122,10 @@ async def create_user(
return await users_repo.create_user(user_create)
@router.get("/{user_id}",response_model=schemas.User)
@router.get("/{user_id}", dependencies=[Depends(get_current_active_user)], response_model=schemas.User)
async def get_user(
user_id: UUID,
users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
current_user: schemas.User = Depends(get_current_active_user)
) -> schemas.User:
"""
Get an user.
@ -91,12 +137,11 @@ async def get_user(
return user
@router.put("/{user_id}", response_model=schemas.User)
@router.put("/{user_id}", dependencies=[Depends(get_current_active_user)], response_model=schemas.User)
async def update_user(
user_id: UUID,
user_update: schemas.UserUpdate,
users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
current_user: schemas.User = Depends(get_current_active_user)
users_repo: UsersRepository = Depends(get_repository(UsersRepository))
) -> schemas.User:
"""
Update an user.
@ -126,27 +171,6 @@ async def delete_user(
raise ControllerNotFoundError(f"User '{user_id}' not found")
@router.post("/login", response_model=schemas.Token)
async def login(
users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
form_data: OAuth2PasswordRequestForm = Depends(),
) -> schemas.Token:
"""
User login.
"""
user = await users_repo.authenticate_user(username=form_data.username, password=form_data.password)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication was unsuccessful.",
headers={"WWW-Authenticate": "Bearer"},
)
token = schemas.Token(access_token=auth_service.create_access_token(user.username), token_type="bearer")
return token
@router.get("/users/me/", response_model=schemas.User)
async def get_current_active_user(current_user: schemas.User = Depends(get_current_active_user)) -> schemas.User:
"""