164 lines
5.8 KiB
Python
164 lines
5.8 KiB
Python
"""
Authentication API routes
"""
|
|
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
|
from fastapi.security import HTTPBearer
|
|
from datetime import datetime, timedelta
|
|
from typing import Optional
|
|
from jose import jwt
|
|
from pydantic import BaseModel
|
|
from sqlalchemy.orm import Session
|
|
|
|
from ...config.settings import settings
|
|
from ...database import SessionLocal
|
|
from ...models import User
|
|
from ...services.auth_service import authenticate_user, create_access_token, get_oidc_config, handle_oidc_callback
|
|
|
|
# Router on which the authentication endpoints of this module are registered.
router = APIRouter()
# HTTP bearer security scheme instance; no route visible in this module depends on it.
security = HTTPBearer()
|
|
|
|
# Pydantic models for auth
|
|
class Token(BaseModel):
    # Response model declared for the POST /token route.

    # Encoded JWT string returned to the client.
    access_token: str
    # Token scheme label; the routes in this module always send "bearer".
    token_type: str

    class Config:
        # Sample payload attached to the generated JSON schema.
        json_schema_extra = {
            "example": dict(
                access_token="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
                token_type="bearer",
            ),
        }
|
|
|
|
class TokenData(BaseModel):
    # Optional username / tenant pair. Not referenced by any route visible in
    # this module; presumably meant to hold decoded token claims (the tokens
    # issued here carry "sub" and "tenant_id") — confirm against callers.

    # Both fields default to None when absent.
    username: Optional[str] = None
    tenant_id: Optional[str] = None
|
|
|
|
class UserLogin(BaseModel):
    # Request body accepted by the POST /token route.

    # Credentials passed on to authenticate_user().
    username: str
    password: str

    class Config:
        # Sample payload attached to the generated JSON schema.
        json_schema_extra = {
            "example": dict(
                username="johndoe",
                password="securepassword",
            ),
        }
|
|
|
|
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Encode ``data`` as a signed JWT carrying an ``exp`` claim.

    NOTE(review): this definition shadows the ``create_access_token`` name
    imported from ``services.auth_service`` at the top of the module, so the
    routes in this file call this local version. Confirm which one is meant
    to be used and drop the other.

    Args:
        data: Claims to embed in the token. The dict is copied, not mutated.
        expires_delta: Lifetime of the token. ``None`` selects the 15-minute
            default; any explicit value, including ``timedelta(0)``, is
            honoured as given.

    Returns:
        The token produced by ``jwt.encode`` using ``settings.SECRET_KEY``
        and ``settings.ALGORITHM``.
    """
    # Local import: the module header only brings in datetime and timedelta.
    from datetime import timezone

    # Compare against None instead of relying on truthiness: timedelta(0) is
    # falsy, so the old `if expires_delta:` silently replaced an explicit
    # zero lifetime with the 15-minute default.
    lifetime = timedelta(minutes=15) if expires_delta is None else expires_delta

    # Timezone-aware "now" in UTC; datetime.utcnow() returns a naive value and
    # is deprecated as of Python 3.12.
    expire = datetime.now(timezone.utc) + lifetime

    to_encode = data.copy()
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
|
|
|
def _get_db():
    """Yield one database session per request and always close it afterwards."""
    # SessionLocal is presumably a SQLAlchemy session factory (the route
    # annotates the result as Session) — confirm in ...database.
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


@router.post("/token", response_model=Token)
async def login_for_access_token(form_data: UserLogin, db: Session = Depends(_get_db)):
    """Exchange a username/password pair for a bearer access token.

    Args:
        form_data: Credentials sent in the request body.
        db: Database session supplied by the ``_get_db`` dependency. The
            previous ``Depends(SessionLocal)`` never closed the session and
            made FastAPI inspect the session factory's own call signature for
            request parameters.

    Returns:
        A dict with ``access_token`` and ``token_type`` ("bearer").

    Raises:
        HTTPException: 401 when ``authenticate_user`` returns a falsy value.
    """
    user = authenticate_user(db, form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(
        data={"sub": user.username, "tenant_id": user.tenant_id},
        expires_delta=access_token_expires,
    )
    return {"access_token": access_token, "token_type": "bearer"}
|
|
|
|
@router.get("/oidc-config")
async def get_oidc_config_endpoint():
    """Return the result of the auth service's ``get_oidc_config()`` call."""
    oidc_config = await get_oidc_config()
    return oidc_config
|
|
|
|
|
|
@router.get("/oidc-login")
async def oidc_login():
    """Build the authorization URL that starts the OIDC login flow.

    Returns:
        ``{"auth_url": <url>}`` where the URL is ``<issuer>/authorize`` plus a
        URL-encoded query string (client id, response type, redirect URI,
        scope and state).

    Raises:
        HTTPException: 500 when the issuer, client id or redirect URI setting
            is missing or empty.
    """
    # Local import: urllib.parse is not among the module-level imports.
    from urllib.parse import urlencode

    if not settings.OIDC_ISSUER or not settings.OIDC_CLIENT_ID or not settings.OIDC_REDIRECT_URI:
        raise HTTPException(status_code=500, detail="OIDC not properly configured")

    # urlencode escapes every value. The previous hand-built string left the
    # spaces in the scope and the reserved characters of the redirect URI
    # unescaped, which yields an invalid or ambiguous query string.
    query = urlencode(
        {
            "client_id": settings.OIDC_CLIENT_ID,
            "response_type": "code",
            "redirect_uri": settings.OIDC_REDIRECT_URI,
            "scope": "openid profile email",
            # SECURITY NOTE(review): a constant state gives no CSRF protection.
            # Generate a per-request value (e.g. secrets.token_urlsafe), store
            # it, and verify it in the callback before production use.
            "state": "random_state_string",
        }
    )

    # Drop a trailing slash so the path never becomes "//authorize".
    issuer = str(settings.OIDC_ISSUER).rstrip("/")
    return {"auth_url": f"{issuer}/authorize?{query}"}
|
|
|
|
|
|
@router.get("/oidc-callback")
async def oidc_callback(code: str, state: Optional[str] = None):
    """Handle the redirect back from the OIDC provider.

    Args:
        code: Authorization code passed to ``handle_oidc_callback``.
        state: State value echoed by the provider. It is not verified here;
            the only use is to return it in the demo response.

    Returns:
        A dict with ``access_token`` and ``token_type`` when the auth service
        resolves a user, or a demo payload echoing ``code`` and ``state`` when
        the service raises ``NotImplementedError``.
    """
    # SECURITY NOTE(review): verify `state` against the value issued at login
    # before trusting the callback in a real deployment.

    # Only the service call can signal "not implemented", so the try block is
    # limited to it and errors in token creation are no longer masked.
    try:
        user = await handle_oidc_callback(code)
    except NotImplementedError:
        # For demo purposes, return a mock response
        return {
            "message": "OIDC callback received. In a real implementation, this would complete the login process.",
            "code": code,
            "state": state,
        }

    access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(
        data={"sub": user.username, "tenant_id": user.tenant_id},
        expires_delta=access_token_expires,
    )
    return {"access_token": access_token, "token_type": "bearer"}
|
|
|
|
|
|
@router.get("/social-login/{provider}")
async def social_login(provider: str):
    """Start a social login for ``provider`` (google, facebook or github).

    Returns a placeholder payload with a mock authorization URL; any other
    provider name is rejected with HTTP 400.
    """
    supported_providers = ["google", "facebook", "github"]
    if provider in supported_providers:
        # Placeholder only: a real implementation would redirect to the
        # provider's actual OAuth endpoint.
        return {
            "message": f"Redirect user to {provider} for authentication",
            "auth_url": f"https://{provider}.com/oauth/authorize",
            "provider": provider,
        }

    raise HTTPException(status_code=400, detail="Unsupported social provider")
|
|
|
|
|
|
@router.post("/social-login/{provider}/callback")
async def social_login_callback(provider: str, access_token: str):
    """Handle the callback of a social login (google, facebook or github).

    Args:
        provider: Name of the social provider taken from the path.
        access_token: Provider-issued token forwarded to
            ``handle_social_login``.

    Returns:
        A dict with ``access_token`` (the application's own JWT) and
        ``token_type`` when a user is resolved, or a demo payload echoing the
        provider and the supplied token when social login is not implemented.

    Raises:
        HTTPException: 400 for a provider outside the supported set.
    """
    if provider not in ["google", "facebook", "github"]:
        raise HTTPException(status_code=400, detail="Unsupported social provider")

    # SECURITY NOTE(review): `access_token` is a plain function parameter on
    # this route, so FastAPI reads it from the query string, and the demo
    # response below echoes it back. Move it to the request body and stop
    # echoing it before production use.
    try:
        # `handle_social_login` is missing from the module-level import list,
        # so the original call raised NameError instead of reaching the
        # NotImplementedError fallback. It is presumably defined next to
        # handle_oidc_callback in services.auth_service — confirm. If it is
        # absent there, treat that as "not implemented" too.
        from ...services.auth_service import handle_social_login

        user = await handle_social_login(provider, access_token)
    except (ImportError, NotImplementedError):
        # For demo purposes, return a mock response
        return {
            "message": "Social login callback received. In a real implementation, this would complete the login process.",
            "provider": provider,
            "access_token": access_token,
        }

    # Separate name so the application's JWT does not shadow the provider
    # token received as a parameter.
    access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    app_token = create_access_token(
        data={"sub": user.username, "tenant_id": user.tenant_id},
        expires_delta=access_token_expires,
    )
    return {"access_token": app_token, "token_type": "bearer"}