"""
 | 
						|
Authentication API routes
 | 
						|
"""
 | 
						|
import secrets
from datetime import datetime, timedelta, timezone
from typing import Iterator, Optional
from urllib.parse import urlencode

from fastapi import APIRouter, Depends, HTTPException, status, Request
from fastapi.security import HTTPBearer
from jose import jwt
from pydantic import BaseModel
from sqlalchemy.orm import Session

from ...config.settings import settings
from ...database import SessionLocal
from ...models import User
from ...services.auth_service import authenticate_user, create_access_token, get_oidc_config, handle_oidc_callback


# Router that collects every authentication endpoint defined in this module;
# presumably included by the application elsewhere (not visible here).
router = APIRouter()
# Bearer-token extractor for the Authorization header.
# NOTE(review): no route in this module uses `security` — confirm it is
# imported from elsewhere before removing it.
security = HTTPBearer()


# Pydantic models for auth
#
# Token: response schema of POST /token (declared as that route's
# response_model). The OIDC and social-login callbacks below return a dict
# with the same keys but do not declare this model.
# Comments are used instead of a class docstring on purpose: pydantic copies a
# model docstring into the generated JSON/OpenAPI schema description.
class Token(BaseModel):
    # Encoded JWT produced by create_access_token.
    access_token: str
    # Token scheme; every endpoint in this module returns the literal "bearer".
    token_type: str

    class Config:
        # Example payload attached to the model's JSON schema (API docs).
        json_schema_extra = {
            "example": {
                "access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
                "token_type": "bearer"
            }
        }


# Claims carried by an access token issued in this module: the tokens are
# built from {"sub": username, "tenant_id": tenant_id}.
# NOTE(review): TokenData is not referenced anywhere in this module —
# presumably used when decoding tokens elsewhere; verify before relying on it.
class TokenData(BaseModel):
    # Value of the token's "sub" claim; None when the claim is absent.
    username: Optional[str] = None
    # Value of the token's "tenant_id" claim; None when the claim is absent.
    tenant_id: Optional[str] = None


# Request body of POST /token: credentials checked by authenticate_user.
# (Comments rather than a docstring, so the OpenAPI description is unchanged.)
class UserLogin(BaseModel):
    # Login name passed as-is to authenticate_user.
    username: str
    # Plain-text password passed as-is to authenticate_user.
    password: str

    class Config:
        # Example payload attached to the model's JSON schema (API docs).
        json_schema_extra = {
            "example": {
                "username": "johndoe",
                "password": "securepassword"
            }
        }


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Sign ``data`` as a JWT that expires after ``expires_delta``.

    NOTE(review): this definition shadows the ``create_access_token`` imported
    from ``auth_service`` at the top of the module, so every route here uses
    this one. Keep exactly one of the two.

    Args:
        data: Claims to embed; copied, never mutated.
        expires_delta: Lifetime of the token. ``None`` means 15 minutes. An
            explicit ``timedelta(0)`` is honoured rather than replaced by the
            default (the previous truthiness check swapped it for 15 minutes).

    Returns:
        The encoded JWT, signed with ``settings.SECRET_KEY`` using
        ``settings.ALGORITHM``.
    """
    to_encode = data.copy()
    if expires_delta is None:
        expires_delta = timedelta(minutes=15)
    # Timezone-aware UTC: datetime.utcnow() is deprecated and returns a naive
    # value that is easy to mix up with local time.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode["exp"] = expire
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)


@router.post("/token", response_model=Token)
 | 
						|
async def login_for_access_token(form_data: UserLogin, db: Session = Depends(SessionLocal)):
 | 
						|
    user = authenticate_user(db, form_data.username, form_data.password)
 | 
						|
    if not user:
 | 
						|
        raise HTTPException(
 | 
						|
            status_code=status.HTTP_401_UNAUTHORIZED,
 | 
						|
            detail="Incorrect username or password",
 | 
						|
            headers={"WWW-Authenticate": "Bearer"},
 | 
						|
        )
 | 
						|
    
 | 
						|
    access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
 | 
						|
    access_token = create_access_token(
 | 
						|
        data={"sub": user.username, "tenant_id": user.tenant_id},
 | 
						|
        expires_delta=access_token_expires
 | 
						|
    )
 | 
						|
    return {"access_token": access_token, "token_type": "bearer"}
 | 
						|
 | 
						|
@router.get("/oidc-config")
 | 
						|
async def get_oidc_config_endpoint():
 | 
						|
    """Get OIDC configuration"""
 | 
						|
    return await get_oidc_config()
 | 
						|
 | 
						|
 | 
						|
@router.get("/oidc-login")
 | 
						|
async def oidc_login():
 | 
						|
    """Initiate OIDC login flow"""
 | 
						|
    if not settings.OIDC_ISSUER or not settings.OIDC_CLIENT_ID or not settings.OIDC_REDIRECT_URI:
 | 
						|
        raise HTTPException(status_code=500, detail="OIDC not properly configured")
 | 
						|
    
 | 
						|
    # Construct the authorization URL
 | 
						|
    auth_url = (
 | 
						|
        f"{settings.OIDC_ISSUER}/authorize?"
 | 
						|
        f"client_id={settings.OIDC_CLIENT_ID}&"
 | 
						|
        f"response_type=code&"
 | 
						|
        f"redirect_uri={settings.OIDC_REDIRECT_URI}&"
 | 
						|
        f"scope=openid profile email&"
 | 
						|
        f"state=random_state_string"  # In real app, generate and store a proper state parameter
 | 
						|
    )
 | 
						|
    
 | 
						|
    return {"auth_url": auth_url}
 | 
						|
 | 
						|
 | 
						|
@router.get("/oidc-callback")
 | 
						|
async def oidc_callback(code: str, state: str = None):
 | 
						|
    """Handle OIDC callback"""
 | 
						|
    # Verify state parameter in a real implementation
 | 
						|
    
 | 
						|
    # Handle the OIDC callback and return a user
 | 
						|
    try:
 | 
						|
        user = await handle_oidc_callback(code)
 | 
						|
        # Create access token for the user
 | 
						|
        access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
 | 
						|
        access_token = create_access_token(
 | 
						|
            data={"sub": user.username, "tenant_id": user.tenant_id},
 | 
						|
            expires_delta=access_token_expires
 | 
						|
        )
 | 
						|
        return {"access_token": access_token, "token_type": "bearer"}
 | 
						|
    except NotImplementedError:
 | 
						|
        # For demo purposes, return a mock response
 | 
						|
        return {
 | 
						|
            "message": "OIDC callback received. In a real implementation, this would complete the login process.",
 | 
						|
            "code": code,
 | 
						|
            "state": state
 | 
						|
        }
 | 
						|
 | 
						|
 | 
						|
@router.get("/social-login/{provider}")
 | 
						|
async def social_login(provider: str):
 | 
						|
    """Initiate social media login flow (Google, Facebook, etc.)"""
 | 
						|
    if provider not in ["google", "facebook", "github"]:
 | 
						|
        raise HTTPException(status_code=400, detail="Unsupported social provider")
 | 
						|
    
 | 
						|
    # In a real implementation, redirect to the provider's OAuth URL
 | 
						|
    # For demo purposes, return a mock URL
 | 
						|
    auth_url = f"https://{provider}.com/oauth/authorize"  # This is just a placeholder
 | 
						|
    return {
 | 
						|
        "message": f"Redirect user to {provider} for authentication",
 | 
						|
        "auth_url": auth_url,
 | 
						|
        "provider": provider
 | 
						|
    }
 | 
						|
 | 
						|
 | 
						|
@router.post("/social-login/{provider}/callback")
 | 
						|
async def social_login_callback(provider: str, access_token: str):
 | 
						|
    """Handle social media login callback"""
 | 
						|
    if provider not in ["google", "facebook", "github"]:
 | 
						|
        raise HTTPException(status_code=400, detail="Unsupported social provider")
 | 
						|
    
 | 
						|
    # In a real implementation, validate the access token and fetch user info
 | 
						|
    # For demo purposes, return a mock response
 | 
						|
    try:
 | 
						|
        user = await handle_social_login(provider, access_token)
 | 
						|
        # Create access token for the user
 | 
						|
        access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
 | 
						|
        access_token = create_access_token(
 | 
						|
            data={"sub": user.username, "tenant_id": user.tenant_id},
 | 
						|
            expires_delta=access_token_expires
 | 
						|
        )
 | 
						|
        return {"access_token": access_token, "token_type": "bearer"}
 | 
						|
    except NotImplementedError:
 | 
						|
        # For demo purposes, return a mock response
 | 
						|
        return {
 | 
						|
            "message": "Social login callback received. In a real implementation, this would complete the login process.",
 | 
						|
            "provider": provider,
 | 
						|
            "access_token": access_token
 | 
						|
        } |