package middleware import ( "context" "fmt" "log" "net/http" "strings" "time" "mohportal/config" "mohportal/db" "mohportal/models" "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" "golang.org/x/oauth2" "github.com/coreos/go-oidc/v3/oidc" "github.com/redis/go-redis/v9" ) var ( redisClient *redis.Client cfg *config.Config verifier *oidc.IDTokenVerifier oauth2Config *oauth2.Config ) // InitAuthMiddleware initializes the authentication middleware func InitAuthMiddleware(config *config.Config) { cfg = config // Initialize Redis client redisClient = redis.NewClient(&redis.Options{ Addr: cfg.RedisURL, }) // Initialize OIDC verifier provider, err := oidc.NewProvider(context.Background(), cfg.OIDCIssuer) if err != nil { log.Fatal("Failed to initialize OIDC provider:", err) } verifier = provider.Verifier(&oidc.Config{ClientID: cfg.OIDCClientID}) // Initialize OAuth2 config oauth2Config = &oauth2.Config{ ClientID: cfg.OIDCClientID, ClientSecret: cfg.OIDCClientSecret, Endpoint: provider.Endpoint(), RedirectURL: "http://localhost:17000/api/v1/auth/callback", Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, } } // JWTAuthMiddleware validates JWT tokens func JWTAuthMiddleware() gin.HandlerFunc { return func(c *gin.Context) { authHeader := c.GetHeader("Authorization") if authHeader == "" { c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization header required"}) c.Abort() return } tokenString := strings.TrimPrefix(authHeader, "Bearer ") if tokenString == authHeader { c.JSON(http.StatusUnauthorized, gin.H{"error": "Bearer token required"}) c.Abort() return } // Parse and validate the token token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } return []byte(cfg.JWTSecret), nil }) if err != nil || !token.Valid { c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"}) c.Abort() return } // Extract claims if claims, ok := token.Claims.(jwt.MapClaims); ok { // Extract user ID from claims if userIDStr, ok := claims["user_id"].(string); ok { userID, err := uuid.Parse(userIDStr) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user ID in token"}) c.Abort() return } // Check if token is still valid in Redis (for logout functionality) tokenKey := fmt.Sprintf("blacklist:%s", tokenString) if val, err := redisClient.Get(context.Background(), tokenKey).Result(); err == nil && val == "true" { c.JSON(http.StatusUnauthorized, gin.H{"error": "Token has been revoked"}) c.Abort() return } // Store user ID in context for use in handlers c.Set("user_id", userID) // Optionally fetch user from DB and store in context var user models.User if err := db.DB.First(&user, "id = ?", userID).Error; err != nil { c.JSON(http.StatusUnauthorized, gin.H{"error": "User not found"}) c.Abort() return } c.Set("user", user) } else { c.JSON(http.StatusUnauthorized, gin.H{"error": "User ID not found in token"}) c.Abort() return } } else { c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token claims"}) c.Abort() return } c.Next() } } // TenantAuthMiddleware ensures the user belongs to the correct tenant func TenantAuthMiddleware() gin.HandlerFunc { return func(c *gin.Context) { // Get user from context (set by JWTAuthMiddleware) user, exists := c.Get("user") if !exists { c.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"}) c.Abort() return } // Type assertion to get user data userData, ok := user.(models.User) if !ok { c.JSON(http.StatusInternalServerError, gin.H{"error": "Error getting user data"}) c.Abort() return } // Check if user's tenant matches the request context // In a real implementation, tenant could come from subdomain, header, or URL parameter // For now, we'll allow access if user is active and belongs to a valid tenant if !userData.IsActive || userData.TenantID == uuid.Nil { c.JSON(http.StatusForbidden, gin.H{"error": "User does not belong to a valid tenant"}) c.Abort() return } c.Next() } } // RoleAuthMiddleware checks if the user has the required role(s) func RoleAuthMiddleware(allowedRoles ...string) gin.HandlerFunc { return func(c *gin.Context) { user, exists := c.Get("user") if !exists { c.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"}) c.Abort() return } userData, ok := user.(models.User) if !ok { c.JSON(http.StatusInternalServerError, gin.H{"error": "Error getting user data"}) c.Abort() return } // Check if user has one of the allowed roles roleValid := false for _, allowedRole := range allowedRoles { if userData.Role == allowedRole { roleValid = true break } } if !roleValid { c.JSON(http.StatusForbidden, gin.H{ "error": "Insufficient permissions", "role": userData.Role, "required": allowedRoles, }) c.Abort() return } c.Next() } } // LogoutHandler invalidates the JWT token by adding it to Redis blacklist func LogoutHandler(c *gin.Context) { authHeader := c.GetHeader("Authorization") if authHeader == "" { c.JSON(http.StatusBadRequest, gin.H{"error": "Authorization header required"}) return } tokenString := strings.TrimPrefix(authHeader, "Bearer ") if tokenString == authHeader { c.JSON(http.StatusBadRequest, gin.H{"error": "Bearer token required"}) return } // Parse token to extract claims for expiration token, _ := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } return []byte(cfg.JWTSecret), nil }) // Get token expiration time var expirationTime time.Time if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { if exp, ok := claims["exp"].(float64); ok { expirationTime = time.Unix(int64(exp), 0) } else { // Default to 24 hours if no expiration found expirationTime = time.Now().Add(24 * time.Hour) } } else { // Default to 24 hours if token is invalid expirationTime = time.Now().Add(24 * time.Hour) } // Add token to Redis blacklist tokenKey := fmt.Sprintf("blacklist:%s", tokenString) ctx := context.Background() duration := time.Until(expirationTime) if duration > 0 { err := redisClient.SetEx(ctx, tokenKey, "true", duration).Err() if err != nil { log.Printf("Error adding token to blacklist: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "Logout failed"}) return } } c.JSON(http.StatusOK, gin.H{"message": "Successfully logged out"}) } // OIDCLoginHandler initiates OIDC login flow func OIDCLoginHandler(c *gin.Context) { state := generateRandomState() authURL := oauth2Config.AuthCodeURL(state, oauth2.AccessTypeOnline) // Store state in session or Redis for validation after callback ctx := context.Background() err := redisClient.SetEx(ctx, fmt.Sprintf("oidc_state:%s", state), "valid", 5*time.Minute).Err() if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Internal server error"}) return } c.Redirect(http.StatusTemporaryRedirect, authURL) } // OIDCCallbackHandler handles OIDC callback func OIDCCallbackHandler(c *gin.Context) { // Get authorization code and state from query params code := c.Query("code") state := c.Query("state") if code == "" || state == "" { c.JSON(http.StatusBadRequest, gin.H{"error": "Missing code or state parameter"}) return } // Verify state parameter ctx := context.Background() storedState, err := redisClient.Get(ctx, fmt.Sprintf("oidc_state:%s", state)).Result() if err != nil || storedState != "valid" { c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid or expired state parameter"}) return } // Remove the state from Redis (one-time use) redisClient.Del(ctx, fmt.Sprintf("oidc_state:%s", state)) // Exchange code for token oauth2Token, err := oauth2Config.Exchange(ctx, code) if err != nil { log.Printf("Failed to exchange code for token: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to exchange code for token"}) return } // Extract ID token rawIDToken, ok := oauth2Token.Extra("id_token").(string) if !ok { c.JSON(http.StatusInternalServerError, gin.H{"error": "No id_token in token response"}) return } // Verify ID token idToken, err := verifier.Verify(ctx, rawIDToken) if err != nil { log.Printf("Failed to verify ID token: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to verify ID token"}) return } // Extract claims var claims struct { Email string `json:"email"` Name string `json:"name"` Subject string `json:"sub"` Verified bool `json:"email_verified"` } if err := idToken.Claims(&claims); err != nil { log.Printf("Failed to parse ID token claims: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to parse ID token claims"}) return } // Check if user exists with this OIDC identity var oidcIdentity models.OIDCIdentity result := db.DB.Where("provider_name = ? AND provider_subject = ?", "oidc", claims.Subject).First(&oidcIdentity) var user models.User if result.Error != nil { // User doesn't exist, create new user parts := strings.Split(claims.Name, " ") firstName := claims.Name lastName := "" if len(parts) > 1 { firstName = parts[0] lastName = strings.Join(parts[1:], " ") } user = models.User{ Email: claims.Email, FirstName: firstName, LastName: lastName, EmailVerified: claims.Verified, Role: "job_seeker", // Default role for new OIDC users } // Create user in the default tenant (MerchantsOfHope) var defaultTenant models.Tenant if err := db.DB.Where("slug = ?", "merchants-of-hope").First(&defaultTenant).Error; err != nil { log.Printf("Failed to get default tenant: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get default tenant"}) return } user.TenantID = defaultTenant.ID if err := db.DB.Create(&user).Error; err != nil { log.Printf("Failed to create user: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"}) return } // Create OIDC identity oidcIdentity = models.OIDCIdentity{ UserID: user.ID, ProviderName: "oidc", ProviderSubject: claims.Subject, } if err := db.DB.Create(&oidcIdentity).Error; err != nil { log.Printf("Failed to create OIDC identity: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create OIDC identity"}) return } } else { // User exists, update user info if needed if err := db.DB.First(&user, oidcIdentity.UserID).Error; err != nil { log.Printf("Failed to find user for OIDC identity: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to find user for OIDC identity"}) return } // Update user info if it has changed updateNeeded := false if user.Email != claims.Email { user.Email = claims.Email updateNeeded = true } parts := strings.Split(claims.Name, " ") firstName := claims.Name lastName := "" if len(parts) > 1 { firstName = parts[0] lastName = strings.Join(parts[1:], " ") } if user.FirstName != firstName { user.FirstName = firstName updateNeeded = true } if user.LastName != lastName { user.LastName = lastName updateNeeded = true } if updateNeeded { if err := db.DB.Save(&user).Error; err != nil { log.Printf("Failed to update user: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update user"}) return } } } // Generate JWT token for the user token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ "user_id": user.ID.String(), "email": user.Email, "role": user.Role, "exp": time.Now().Add(time.Hour * 24).Unix(), // Token expires in 24 hours }) tokenString, err := token.SignedString([]byte(cfg.JWTSecret)) if err != nil { log.Printf("Failed to generate JWT token: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate JWT token"}) return } // Return token to client c.JSON(http.StatusOK, gin.H{ "token": tokenString, "user": user, "method": "oidc", }) } // SocialLoginHandler initiates social media login flow (for providers like Google, Facebook, etc.) func SocialLoginHandler(c *gin.Context) { provider := c.Param("provider") // Validate provider supportedProviders := map[string]string{ "google": "https://accounts.google.com/o/oauth2/v2/auth", "facebook": "https://www.facebook.com/v17.0/dialog/oauth", "github": "https://github.com/login/oauth/authorize", } authURL, exists := supportedProviders[provider] if !exists { c.JSON(http.StatusBadRequest, gin.H{"error": "Unsupported provider"}) return } state := generateRandomState() // Store state in session or Redis for validation after callback ctx := context.Background() err := redisClient.SetEx(ctx, fmt.Sprintf("social_state:%s:%s", provider, state), "valid", 5*time.Minute).Err() if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Internal server error"}) return } // Construct the auth URL based on the provider var redirectURL string switch provider { case "google": redirectURL = fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&response_type=code&scope=openid profile email&state=%s", authURL, cfg.OIDCClientID, "http://localhost:17000/api/v1/auth/social/callback/google", state) case "github": redirectURL = fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&scope=user:email&state=%s", authURL, cfg.OIDCClientID, "http://localhost:17000/api/v1/auth/social/callback/github", state) } c.Redirect(http.StatusTemporaryRedirect, redirectURL) } // SocialCallbackHandler handles social media login callback func SocialCallbackHandler(c *gin.Context) { provider := c.Param("provider") code := c.Query("code") state := c.Query("state") if code == "" || state == "" { c.JSON(http.StatusBadRequest, gin.H{"error": "Missing code or state parameter"}) return } // Verify state parameter ctx := context.Background() storedState, err := redisClient.Get(ctx, fmt.Sprintf("social_state:%s:%s", provider, state)).Result() if err != nil || storedState != "valid" { c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid or expired state parameter"}) return } // Remove the state from Redis (one-time use) redisClient.Del(ctx, fmt.Sprintf("social_state:%s:%s", provider, state)) // Get user info from social provider userInfo, err := getUserInfoFromProvider(provider, code) if err != nil { log.Printf("Failed to get user info from %s: %v", provider, err) c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Failed to get user info from %s", provider)}) return } // Check if user exists with this social identity var socialIdentity models.SocialIdentity result := db.DB.Where("provider_name = ? AND provider_user_id = ?", provider, userInfo.ProviderUserID).First(&socialIdentity) var user models.User if result.Error != nil { // User doesn't exist, create new user user = models.User{ Email: userInfo.Email, FirstName: userInfo.FirstName, LastName: userInfo.LastName, EmailVerified: userInfo.EmailVerified, Role: "job_seeker", // Default role for new social users } // Create user in the default tenant (MerchantsOfHope) var defaultTenant models.Tenant if err := db.DB.Where("slug = ?", "merchants-of-hope").First(&defaultTenant).Error; err != nil { log.Printf("Failed to get default tenant: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get default tenant"}) return } user.TenantID = defaultTenant.ID if err := db.DB.Create(&user).Error; err != nil { log.Printf("Failed to create user: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"}) return } // Create social identity socialIdentity = models.SocialIdentity{ UserID: user.ID, ProviderName: provider, ProviderUserID: userInfo.ProviderUserID, AccessToken: userInfo.AccessToken, RefreshToken: userInfo.RefreshToken, ExpiresAt: userInfo.ExpiresAt, ProfileData: userInfo.ProfileData, } if err := db.DB.Create(&socialIdentity).Error; err != nil { log.Printf("Failed to create social identity: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create social identity"}) return } } else { // User exists, update user info if needed if err := db.DB.First(&user, socialIdentity.UserID).Error; err != nil { log.Printf("Failed to find user for social identity: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to find user for social identity"}) return } // Update user info if it has changed updateNeeded := false if user.Email != userInfo.Email { user.Email = userInfo.Email updateNeeded = true } if user.FirstName != userInfo.FirstName { user.FirstName = userInfo.FirstName updateNeeded = true } if user.LastName != userInfo.LastName { user.LastName = userInfo.LastName updateNeeded = true } if updateNeeded { if err := db.DB.Save(&user).Error; err != nil { log.Printf("Failed to update user: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update user"}) return } } // Update social identity socialIdentity.AccessToken = userInfo.AccessToken socialIdentity.RefreshToken = userInfo.RefreshToken socialIdentity.ExpiresAt = userInfo.ExpiresAt socialIdentity.ProfileData = userInfo.ProfileData if err := db.DB.Save(&socialIdentity).Error; err != nil { log.Printf("Failed to update social identity: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update social identity"}) return } } // Generate JWT token for the user token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ "user_id": user.ID.String(), "email": user.Email, "role": user.Role, "exp": time.Now().Add(time.Hour * 24).Unix(), // Token expires in 24 hours }) tokenString, err := token.SignedString([]byte(cfg.JWTSecret)) if err != nil { log.Printf("Failed to generate JWT token: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate JWT token"}) return } // Return token to client c.JSON(http.StatusOK, gin.H{ "token": tokenString, "user": user, "method": "social", "provider": provider, }) } // SocialUserInfo represents the user information returned from social providers type SocialUserInfo struct { ProviderUserID string Email string FirstName string LastName string EmailVerified bool AccessToken string RefreshToken string ExpiresAt *time.Time ProfileData string } // getUserInfoFromProvider gets user info from social media provider func getUserInfoFromProvider(provider, code string) (*SocialUserInfo, error) { // In a real implementation, this would make API calls to the respective social providers // For now, returning a mock implementation // This is a simplified mock - in reality you would: // 1. Exchange the code for an access token // 2. Use the access token to get user profile information // 3. Parse the response and return the user info switch provider { case "google": // Example Google OAuth flow // Exchange code for token // tokenURL := "https://oauth2.googleapis.com/token" // ... perform token exchange ... // Get user info // userInfoURL := "https://www.googleapis.com/oauth2/v2/userinfo" // ... perform user info request ... // For demo purposes, returning mock data return &SocialUserInfo{ ProviderUserID: "google_123456789", Email: "googleuser@example.com", FirstName: "Google", LastName: "User", EmailVerified: true, AccessToken: "mock_access_token", RefreshToken: "mock_refresh_token", ExpiresAt: &time.Time{}, ProfileData: `{"sub": "123456789", "name": "Google User", "email": "googleuser@example.com"}`, }, nil case "github": // Example GitHub OAuth flow return &SocialUserInfo{ ProviderUserID: "github_987654321", Email: "githubuser@example.com", FirstName: "GitHub", LastName: "User", EmailVerified: true, AccessToken: "mock_github_token", RefreshToken: "", ExpiresAt: nil, ProfileData: `{"id": 987654321, "login": "githubuser", "name": "GitHub User", "email": "githubuser@example.com"}`, }, nil default: return nil, fmt.Errorf("unsupported provider: %s", provider) } } // generateRandomState generates a random state parameter for OIDC func generateRandomState() string { b := make([]byte, 32) for i := range b { b[i] = byte('a' + (i % 26)) } return string(b) }