// Package middleware provides Gin middleware for JWT authentication,
// role-based access control, CORS headers, per-IP rate limiting, request
// logging and panic recovery.
package middleware

import (
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt/v5"

	"github.com/ydn/yourdreamnamehere/internal/config"
)

// Claims is the JWT payload parsed by AuthMiddleware: the application
// fields below plus the embedded registered claims.
type Claims struct {
	UserID string `json:"user_id"`
	Email  string `json:"email"`
	Role   string `json:"role"`
	jwt.RegisteredClaims
}

// AuthMiddleware returns a handler that authenticates requests carrying an
// "Authorization: Bearer <token>" header.
//
// The token is parsed into Claims and must be signed with an HMAC method
// using cfg.JWT.Secret. On success the claims are stored in the Gin context
// under the keys "user_id", "email" and "role" and the chain continues;
// otherwise the request is aborted with 401 and a JSON error body.
func AuthMiddleware(cfg *config.Config) gin.HandlerFunc {
	// keyFunc rejects any non-HMAC signing method before handing out the
	// shared secret, so a token cannot pick its own verification algorithm.
	keyFunc := func(token *jwt.Token) (interface{}, error) {
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, jwt.ErrSignatureInvalid
		}
		return []byte(cfg.JWT.Secret), nil
	}

	return func(c *gin.Context) {
		authHeader := c.GetHeader("Authorization")
		if authHeader == "" {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization header required"})
			return
		}

		// Exactly "<scheme> <token>" separated by a single space. The scheme
		// is compared case-insensitively, as HTTP auth schemes are.
		parts := strings.Split(authHeader, " ")
		if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid authorization header format"})
			return
		}

		token, err := jwt.ParseWithClaims(parts[1], &Claims{}, keyFunc)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
			return
		}

		claims, ok := token.Claims.(*Claims)
		if !ok || !token.Valid {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid token claims"})
			return
		}

		c.Set("user_id", claims.UserID)
		c.Set("email", claims.Email)
		c.Set("role", claims.Role)
		c.Next()
	}
}

// RequireRole returns a handler for role-based access control. It reads the
// "role" value from the Gin context (set by AuthMiddleware) and lets the
// request through when that role equals one of roles, or is "admin".
//
// The check is made per listed role, so when roles is empty no request is
// permitted, including "admin". Failures abort with 403 and a JSON error
// body.
func RequireRole(roles ...string) gin.HandlerFunc {
	return func(c *gin.Context) {
		rawRole, exists := c.Get("role")
		if !exists {
			c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "User role not found"})
			return
		}

		userRole, ok := rawRole.(string)
		if !ok {
			c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "Invalid user role format"})
			return
		}

		permitted := false
		for _, role := range roles {
			if userRole == role || userRole == "admin" {
				permitted = true
				break
			}
		}
		if !permitted {
			c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "Insufficient permissions"})
			return
		}

		c.Next()
	}
}

// CORSMiddleware returns a handler that writes CORS response headers.
//
// Access-Control-Allow-Origin is set to the request's Origin only when that
// origin exactly matches an entry in cfg.Security.CORSOrigins; the remaining
// CORS headers are written unconditionally. OPTIONS requests are answered
// with 204 and do not reach later handlers.
func CORSMiddleware(cfg *config.Config) gin.HandlerFunc {
	return func(c *gin.Context) {
		origin := c.Request.Header.Get("Origin")

		// The Allow-Origin value depends on the request's Origin header, so
		// caches must key on it. Add (not Set) keeps any existing Vary values.
		c.Writer.Header().Add("Vary", "Origin")

		for _, allowedOrigin := range cfg.Security.CORSOrigins {
			if origin == allowedOrigin {
				c.Header("Access-Control-Allow-Origin", origin)
				break
			}
		}

		c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
		c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization")
		c.Header("Access-Control-Expose-Headers", "Content-Length")
		c.Header("Access-Control-Allow-Credentials", "true")

		if c.Request.Method == "OPTIONS" {
			c.AbortWithStatus(http.StatusNoContent)
			return
		}

		c.Next()
	}
}

// RateLimitMiddleware returns a handler that limits each client IP (as
// reported by c.ClientIP) to cfg.Security.RateLimitRequests requests per
// cfg.Security.RateLimitWindow. A client's window starts at its first
// request and restarts on the first request after it has expired.
//
// Requests over the limit are aborted with 429 and a JSON body carrying the
// window length in seconds as "retry_after". Counters are kept in memory in
// this handler's closure and are safe for concurrent requests.
func RateLimitMiddleware(cfg *config.Config) gin.HandlerFunc {
	type client struct {
		windowStart time.Time
		requests    int
	}

	var (
		mu        sync.Mutex // guards clients and lastSweep
		clients   = make(map[string]*client)
		lastSweep = time.Now()
	)

	return func(c *gin.Context) {
		window := cfg.Security.RateLimitWindow
		limit := cfg.Security.RateLimitRequests
		clientIP := c.ClientIP()
		now := time.Now()

		mu.Lock()

		// At most once per window, drop entries whose window has expired so
		// the map does not grow without bound. An expired entry is treated
		// exactly like an unseen client, so evicting it changes nothing.
		if now.Sub(lastSweep) > window {
			for ip, cl := range clients {
				if now.Sub(cl.windowStart) > window {
					delete(clients, ip)
				}
			}
			lastSweep = now
		}

		cl, exists := clients[clientIP]
		switch {
		case !exists:
			cl = &client{windowStart: now}
			clients[clientIP] = cl
		case now.Sub(cl.windowStart) > window:
			// Window expired: start a fresh one.
			cl.requests = 0
			cl.windowStart = now
		}

		limited := cl.requests >= limit
		if !limited {
			cl.requests++
		}

		// Release the lock before writing the response or running the rest
		// of the chain, so slow handlers do not serialize all requests.
		mu.Unlock()

		if limited {
			c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
				"error":       "Rate limit exceeded",
				"retry_after": window.Seconds(),
			})
			return
		}

		c.Next()
	}
}

// LoggingMiddleware returns gin's logger middleware configured with a
// formatter that always yields an empty string.
//
// NOTE(review): as written the formatter produces no log text; this looks
// like a placeholder — confirm whether real request logging is intended.
func LoggingMiddleware() gin.HandlerFunc {
	return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
		return ""
	})
}

// ErrorMiddleware returns gin's built-in Recovery middleware.
func ErrorMiddleware() gin.HandlerFunc {
	return gin.Recovery()
}