.
This commit is contained in:
@@ -1,31 +1,44 @@
|
||||
"""
|
||||
Tenant middleware for handling multi-tenant requests
|
||||
"""
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from fastapi import Request, HTTPException, status
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .config.settings import settings
|
||||
from .models import Tenant
|
||||
from .database import SessionLocal
|
||||
|
||||
|
||||
class TenantMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
# Get tenant ID from header or subdomain
|
||||
# Get tenant ID from header
|
||||
tenant_id = request.headers.get(settings.TENANT_ID_HEADER)
|
||||
|
||||
# If not in header, try to extract from subdomain
|
||||
if not tenant_id:
|
||||
# Try to extract tenant from subdomain
|
||||
host = request.headers.get("host", "")
|
||||
tenant_id = self.extract_tenant_from_host(host)
|
||||
|
||||
if not tenant_id and settings.MULTI_TENANT_ENABLED:
|
||||
# Look up tenant in database
|
||||
tenant = None
|
||||
if tenant_id:
|
||||
db: Session = SessionLocal()
|
||||
try:
|
||||
tenant = db.query(Tenant).filter(Tenant.subdomain == tenant_id).first()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
if settings.MULTI_TENANT_ENABLED and not tenant:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Tenant ID is required"
|
||||
detail="Valid tenant ID is required"
|
||||
)
|
||||
|
||||
# Attach tenant info to request
|
||||
request.state.tenant_id = tenant_id
|
||||
request.state.tenant = tenant
|
||||
request.state.tenant_id = tenant.id if tenant else None
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
@@ -34,7 +47,9 @@ class TenantMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Extract tenant from host (subdomain.tenant.com)
|
||||
"""
|
||||
# For now, return a default tenant or None
|
||||
# In a real implementation, you would parse the subdomain
|
||||
# and look up the corresponding tenant in the database
|
||||
return "default"
|
||||
import re
|
||||
# Match subdomain from host (e.g., "tenant1.example.com" -> "tenant1")
|
||||
match = re.match(r'^([^.]+)\.', host)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return None
|
||||
Reference in New Issue
Block a user