.
This commit is contained in:
		@@ -1,31 +1,44 @@
 | 
			
		||||
"""
 | 
			
		||||
Tenant middleware for handling multi-tenant requests
 | 
			
		||||
"""
 | 
			
		||||
import uuid
 | 
			
		||||
from typing import Optional
 | 
			
		||||
from fastapi import Request, HTTPException, status
 | 
			
		||||
from starlette.middleware.base import BaseHTTPMiddleware
 | 
			
		||||
from sqlalchemy.orm import Session
 | 
			
		||||
 | 
			
		||||
from .config.settings import settings
 | 
			
		||||
from .models import Tenant
 | 
			
		||||
from .database import SessionLocal
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TenantMiddleware(BaseHTTPMiddleware):
 | 
			
		||||
    async def dispatch(self, request: Request, call_next):
 | 
			
		||||
        # Get tenant ID from header or subdomain
 | 
			
		||||
        # Get tenant ID from header
 | 
			
		||||
        tenant_id = request.headers.get(settings.TENANT_ID_HEADER)
 | 
			
		||||
        
 | 
			
		||||
        # If not in header, try to extract from subdomain
 | 
			
		||||
        if not tenant_id:
 | 
			
		||||
            # Try to extract tenant from subdomain
 | 
			
		||||
            host = request.headers.get("host", "")
 | 
			
		||||
            tenant_id = self.extract_tenant_from_host(host)
 | 
			
		||||
        
 | 
			
		||||
        if not tenant_id and settings.MULTI_TENANT_ENABLED:
 | 
			
		||||
        # Look up tenant in database
 | 
			
		||||
        tenant = None
 | 
			
		||||
        if tenant_id:
 | 
			
		||||
            db: Session = SessionLocal()
 | 
			
		||||
            try:
 | 
			
		||||
                tenant = db.query(Tenant).filter(Tenant.subdomain == tenant_id).first()
 | 
			
		||||
            finally:
 | 
			
		||||
                db.close()
 | 
			
		||||
        
 | 
			
		||||
        if settings.MULTI_TENANT_ENABLED and not tenant:
 | 
			
		||||
            raise HTTPException(
 | 
			
		||||
                status_code=status.HTTP_400_BAD_REQUEST,
 | 
			
		||||
                detail="Tenant ID is required"
 | 
			
		||||
                detail="Valid tenant ID is required"
 | 
			
		||||
            )
 | 
			
		||||
        
 | 
			
		||||
        # Attach tenant info to request
 | 
			
		||||
        request.state.tenant_id = tenant_id
 | 
			
		||||
        request.state.tenant = tenant
 | 
			
		||||
        request.state.tenant_id = tenant.id if tenant else None
 | 
			
		||||
        
 | 
			
		||||
        response = await call_next(request)
 | 
			
		||||
        return response
 | 
			
		||||
@@ -34,7 +47,9 @@ class TenantMiddleware(BaseHTTPMiddleware):
 | 
			
		||||
        """
 | 
			
		||||
        Extract tenant from host (subdomain.tenant.com)
 | 
			
		||||
        """
 | 
			
		||||
        # For now, return a default tenant or None
 | 
			
		||||
        # In a real implementation, you would parse the subdomain
 | 
			
		||||
        # and look up the corresponding tenant in the database
 | 
			
		||||
        return "default"
 | 
			
		||||
        import re
 | 
			
		||||
        # Match subdomain from host (e.g., "tenant1.example.com" -> "tenant1")
 | 
			
		||||
        match = re.match(r'^([^.]+)\.', host)
 | 
			
		||||
        if match:
 | 
			
		||||
            return match.group(1)
 | 
			
		||||
        return None
 | 
			
		||||
		Reference in New Issue
	
	Block a user