55 lines
1.8 KiB
Python
55 lines
1.8 KiB
Python
"""
|
|
Tenant middleware for handling multi-tenant requests
|
|
"""
|
|
from typing import Optional
|
|
from fastapi import Request, HTTPException, status
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from sqlalchemy.orm import Session
|
|
|
|
from .config.settings import settings
|
|
from .models import Tenant
|
|
from .database import SessionLocal
|
|
|
|
|
|
class TenantMiddleware(BaseHTTPMiddleware):
    """Resolve the tenant for each request and attach it to ``request.state``.

    The tenant key is read from the ``settings.TENANT_ID_HEADER`` header and,
    when that is absent or empty, derived from the leftmost label of the
    ``Host`` header. The key is matched against ``Tenant.subdomain``.

    After this middleware runs, ``request.state.tenant`` holds the matched
    ``Tenant`` (or ``None``) and ``request.state.tenant_id`` holds its ``id``
    (or ``None``).

    NOTE(review): the tenant header is supplied by the client, so any caller
    can select any tenant by setting it. Confirm that downstream authorization
    checks that the caller actually belongs to ``request.state.tenant``.
    """

    async def dispatch(self, request: Request, call_next):
        """Resolve the tenant, reject the request if required, else forward it.

        Returns a 400 JSON response of the form ``{"detail": ...}`` when
        ``settings.MULTI_TENANT_ENABLED`` is true and no matching tenant was
        found. Otherwise returns the downstream response from ``call_next``.
        """
        # Local imports: the module header does not import these helpers, and
        # both come from Starlette, which this module already depends on.
        from starlette.concurrency import run_in_threadpool
        from starlette.responses import JSONResponse

        # Prefer the explicit header; fall back to the Host subdomain.
        tenant_key = request.headers.get(settings.TENANT_ID_HEADER)
        if not tenant_key:
            tenant_key = self.extract_tenant_from_host(
                request.headers.get("host", "")
            )

        tenant = None
        if tenant_key:
            # The SQLAlchemy session is synchronous. Run the query in a worker
            # thread so it does not block the event loop for other requests.
            tenant = await run_in_threadpool(self._lookup_tenant, tenant_key)

        if settings.MULTI_TENANT_ENABLED and tenant is None:
            # Do not raise HTTPException here. FastAPI's exception handlers
            # sit *inside* the middleware stack, so an exception raised by
            # this middleware surfaces as a 500. Build the 400 response
            # directly, in the same shape FastAPI uses for HTTPException.
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={"detail": "Valid tenant ID is required"},
            )

        # Expose the resolved tenant to route handlers and dependencies.
        request.state.tenant = tenant
        request.state.tenant_id = tenant.id if tenant is not None else None

        return await call_next(request)

    def _lookup_tenant(self, tenant_key: str) -> Optional[Tenant]:
        """Return the ``Tenant`` whose ``subdomain`` equals *tenant_key*, or ``None``.

        Opens a dedicated ``SessionLocal`` session and always closes it, even
        if the query raises.

        NOTE(review): the returned instance outlives its session. Attributes
        loaded by the query are readable, but lazy-loaded relationships
        probably are not. Confirm before relying on them downstream.
        """
        db: Session = SessionLocal()
        try:
            return db.query(Tenant).filter(Tenant.subdomain == tenant_key).first()
        finally:
            db.close()

    def extract_tenant_from_host(self, host: str) -> Optional[str]:
        """Return the leftmost label of *host* as the tenant key, or ``None``.

        Examples: ``"tenant1.example.com:8000"`` -> ``"tenant1"``;
        ``"Tenant1.Example.com"`` -> ``"tenant1"``.

        Behavior:
        - Any ``:port`` suffix is ignored.
        - The label is lower-cased, because host names are case-insensitive.
        - Hosts without a dot (e.g. ``"localhost:8000"``) yield ``None``.
        - IPv4 literals (e.g. ``"127.0.0.1"``) yield ``None``.
        - Bracketed IPv6 literals (e.g. ``"[::1]:8000"``) yield ``None``.
        """
        import ipaddress

        # A bracketed IPv6 literal is an address, never a subdomain.
        if not host or host.startswith("["):
            return None

        hostname = host.partition(":")[0].lower()

        # Reject IPv4 literals, which would otherwise yield their first octet.
        try:
            ipaddress.ip_address(hostname)
        except ValueError:
            pass
        else:
            return None

        # Need a non-empty label followed by a dot (same rule as ``^([^.]+)\.``).
        label, dot, _ = hostname.partition(".")
        return label if dot and label else None