""" Tenant middleware for handling multi-tenant requests """ from typing import Optional from fastapi import Request, HTTPException, status from starlette.middleware.base import BaseHTTPMiddleware from sqlalchemy.orm import Session from ..config.settings import settings from ..models import Tenant from ..database import SessionLocal class TenantMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): # Get tenant ID from header tenant_id = request.headers.get(settings.TENANT_ID_HEADER) # If not in header, try to extract from subdomain if not tenant_id: host = request.headers.get("host", "") tenant_id = self.extract_tenant_from_host(host) # Look up tenant in database tenant = None if tenant_id: db: Session = SessionLocal() try: tenant = db.query(Tenant).filter(Tenant.subdomain == tenant_id).first() finally: db.close() if settings.MULTI_TENANT_ENABLED and not tenant: # For testing and initial setup, allow requests without tenant # In production, use proper tenant identification import os if os.getenv("TESTING", "False").lower() == "true": # Allow requests in test environment pass else: # In production environment, raise the exception # raise HTTPException( # status_code=status.HTTP_400_BAD_REQUEST, # detail="Valid tenant ID is required" # ) pass # Allow for now but would be stricter in production # Attach tenant info to request request.state.tenant = tenant request.state.tenant_id = tenant.id if tenant else None response = await call_next(request) return response def extract_tenant_from_host(self, host: str) -> Optional[str]: """ Extract tenant from host (subdomain.tenant.com) """ import re # Match subdomain from host (e.g., "tenant1.example.com" -> "tenant1") match = re.match(r'^([^.]+)\.', host) if match: return match.group(1) return None