Files
MOHPortalTest-AllAgents-All…/qwen/go/middleware/auth.go
2025-10-24 16:29:40 -05:00

701 lines
21 KiB
Go

package middleware
import (
	"context"
	"crypto/rand"
	"encoding/hex"
	"fmt"
	"log"
	"net/http"
	"strings"
	"time"

	"github.com/coreos/go-oidc/v3/oidc"
	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt/v5"
	"github.com/google/uuid"
	"github.com/redis/go-redis/v9"
	"golang.org/x/oauth2"

	"mohportal/config"
	"mohportal/db"
	"mohportal/models"
)
// Package-level dependencies shared by the middleware and handlers in this
// file. All four are assigned by InitAuthMiddleware and are nil before it runs.
var (
// redisClient stores the token blacklist and the short-lived login state keys.
redisClient *redis.Client
// cfg is the configuration handed to InitAuthMiddleware.
cfg *config.Config
// verifier checks ID tokens received in OIDCCallbackHandler.
verifier *oidc.IDTokenVerifier
// oauth2Config builds the OIDC authorization URL and exchanges codes for tokens.
oauth2Config *oauth2.Config
)
// InitAuthMiddleware populates the package-level state used by the auth
// middleware and handlers: the configuration, the Redis client, the OIDC ID
// token verifier and the OAuth2 client configuration.
//
// It terminates the process through log.Fatal when the OIDC provider for
// cfg.OIDCIssuer cannot be initialized.
func InitAuthMiddleware(config *config.Config) {
	cfg = config

	// Redis client, addressed by the configured URL.
	redisOpts := redis.Options{Addr: cfg.RedisURL}
	redisClient = redis.NewClient(&redisOpts)

	// The provider supplies both the token verifier and the OAuth2 endpoints,
	// so nothing below can be built without it.
	oidcProvider, providerErr := oidc.NewProvider(context.Background(), cfg.OIDCIssuer)
	if providerErr != nil {
		log.Fatal("Failed to initialize OIDC provider:", providerErr)
	}

	verifierCfg := oidc.Config{ClientID: cfg.OIDCClientID}
	verifier = oidcProvider.Verifier(&verifierCfg)

	// OAuth2 client configuration for the authorization-code flow.
	requestedScopes := []string{oidc.ScopeOpenID, "profile", "email"}
	oauth2Config = &oauth2.Config{
		ClientID:     cfg.OIDCClientID,
		ClientSecret: cfg.OIDCClientSecret,
		Endpoint:     oidcProvider.Endpoint(),
		RedirectURL:  "http://localhost:17000/api/v1/auth/callback",
		Scopes:       requestedScopes,
	}
}
// JWTAuthMiddleware returns a gin handler that authenticates requests carrying
// an HMAC-signed JWT in an "Authorization: Bearer <token>" header.
//
// On success the parsed user ID is stored in the gin context under "user_id"
// and the matching models.User (by value) under "user", then the next handler
// runs. Every failure aborts the chain with 401 and a JSON {"error": ...} body:
// missing header, missing Bearer prefix, invalid token, missing or malformed
// "user_id" claim, blacklisted token, or unknown user.
func JWTAuthMiddleware() gin.HandlerFunc {
	// reject aborts the handler chain with a 401 and the given error message.
	reject := func(c *gin.Context, msg string) {
		c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": msg})
	}

	return func(c *gin.Context) {
		authHeader := c.GetHeader("Authorization")
		if authHeader == "" {
			reject(c, "Authorization header required")
			return
		}
		// TrimPrefix returns its input unchanged when the prefix is absent.
		tokenString := strings.TrimPrefix(authHeader, "Bearer ")
		if tokenString == authHeader {
			reject(c, "Bearer token required")
			return
		}

		// Parse and validate the token; only HMAC signing methods are accepted.
		token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
			if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
				return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
			}
			return []byte(cfg.JWTSecret), nil
		})
		if err != nil || !token.Valid {
			reject(c, "Invalid token")
			return
		}

		claims, ok := token.Claims.(jwt.MapClaims)
		if !ok {
			reject(c, "Invalid token claims")
			return
		}
		userIDStr, ok := claims["user_id"].(string)
		if !ok {
			reject(c, "User ID not found in token")
			return
		}
		userID, err := uuid.Parse(userIDStr)
		if err != nil {
			reject(c, "Invalid user ID in token")
			return
		}

		// Reject tokens that LogoutHandler has blacklisted. The lookup is tied
		// to the request context so it is abandoned if the client goes away.
		tokenKey := fmt.Sprintf("blacklist:%s", tokenString)
		val, err := redisClient.Get(c.Request.Context(), tokenKey).Result()
		switch {
		case err == nil && val == "true":
			reject(c, "Token has been revoked")
			return
		case err != nil && err != redis.Nil:
			// redis.Nil only means "not blacklisted". Any other error means
			// revocation could not be checked; the request is still allowed
			// through (as before), but the failure is no longer silent.
			log.Printf("Error checking token blacklist: %v", err)
		}

		// Store user ID in context for use in handlers.
		c.Set("user_id", userID)

		// Load the user so downstream middleware can check tenant and role.
		var user models.User
		if err := db.DB.First(&user, "id = ?", userID).Error; err != nil {
			reject(c, "User not found")
			return
		}
		c.Set("user", user)

		c.Next()
	}
}
// TenantAuthMiddleware returns a gin handler that lets a request through only
// when the authenticated user is active and has a non-nil tenant ID.
//
// It reads the "user" context value set by JWTAuthMiddleware. It answers 401
// when that value is absent, 500 when it is not a models.User, and 403 when
// the user is inactive or has no tenant. No per-request tenant (subdomain,
// header, URL parameter) is compared here.
func TenantAuthMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		raw, found := c.Get("user")
		if !found {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
			return
		}

		account, isUser := raw.(models.User)
		if !isUser {
			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "Error getting user data"})
			return
		}

		// A usable account is active and attached to some tenant.
		hasValidTenant := account.IsActive && account.TenantID != uuid.Nil
		if !hasValidTenant {
			c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "User does not belong to a valid tenant"})
			return
		}

		c.Next()
	}
}
// RoleAuthMiddleware returns a gin handler that lets a request through only
// when the authenticated user's role equals one of allowedRoles.
//
// It reads the "user" context value set by JWTAuthMiddleware. It answers 401
// when that value is absent, 500 when it is not a models.User, and 403 (with
// the user's role and the required roles in the body) when no role matches.
// With no allowedRoles every request is rejected with 403.
func RoleAuthMiddleware(allowedRoles ...string) gin.HandlerFunc {
	return func(c *gin.Context) {
		raw, found := c.Get("user")
		if !found {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
			return
		}

		account, isUser := raw.(models.User)
		if !isUser {
			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "Error getting user data"})
			return
		}

		// Scan the allowed roles, stopping at the first match.
		permitted := false
		for i := 0; i < len(allowedRoles) && !permitted; i++ {
			permitted = account.Role == allowedRoles[i]
		}

		if permitted {
			c.Next()
			return
		}

		c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
			"error":    "Insufficient permissions",
			"role":     account.Role,
			"required": allowedRoles,
		})
	}
}
// LogoutHandler revokes the bearer token of the current request by writing it
// to the Redis blacklist that JWTAuthMiddleware consults.
//
// The blacklist entry lives until the token's own "exp" claim. When the token
// cannot be parsed, is invalid, or carries no "exp", the entry lives for 24
// hours instead. Responses: 400 when the Authorization header or the Bearer
// prefix is missing, 500 when the Redis write fails, 200 otherwise.
func LogoutHandler(c *gin.Context) {
	authHeader := c.GetHeader("Authorization")
	if authHeader == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Authorization header required"})
		return
	}
	// TrimPrefix returns its input unchanged when the prefix is absent.
	tokenString := strings.TrimPrefix(authHeader, "Bearer ")
	if tokenString == authHeader {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Bearer token required"})
		return
	}

	// Parse the token to read its expiration. The error is deliberately not
	// treated as fatal: an unparseable or invalid token falls back to the
	// default lifetime below.
	token, _ := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return []byte(cfg.JWTSecret), nil
	})

	// Default to 24 hours; narrowed to the real expiration when available.
	expirationTime := time.Now().Add(24 * time.Hour)
	// jwt.Parse returns a nil token for malformed input (e.g. wrong number of
	// segments), so the nil check must come before any field access.
	if token != nil && token.Valid {
		if claims, ok := token.Claims.(jwt.MapClaims); ok {
			if exp, ok := claims["exp"].(float64); ok {
				expirationTime = time.Unix(int64(exp), 0)
			}
		}
	}

	// Add the token to the blacklist for as long as it could still be used.
	// An already-expired token needs no entry.
	tokenKey := fmt.Sprintf("blacklist:%s", tokenString)
	if duration := time.Until(expirationTime); duration > 0 {
		if err := redisClient.SetEx(context.Background(), tokenKey, "true", duration).Err(); err != nil {
			log.Printf("Error adding token to blacklist: %v", err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": "Logout failed"})
			return
		}
	}

	c.JSON(http.StatusOK, gin.H{"message": "Successfully logged out"})
}
// OIDCLoginHandler starts the OIDC login flow: it generates a state value,
// records it in Redis for five minutes so the callback can validate it, and
// redirects the client (307) to the provider's authorization URL.
//
// It answers 500 when the state cannot be stored.
func OIDCLoginHandler(c *gin.Context) {
	state := generateRandomState()
	authURL := oauth2Config.AuthCodeURL(state, oauth2.AccessTypeOnline)

	// The callback looks this key up to confirm the state was issued here.
	stateKey := "oidc_state:" + state
	if storeErr := redisClient.SetEx(context.Background(), stateKey, "valid", 5*time.Minute).Err(); storeErr != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Internal server error"})
		return
	}

	c.Redirect(http.StatusTemporaryRedirect, authURL)
}
// splitFullName splits a display name at its first space into a first name
// and the remainder. A name without a space is returned whole as the first
// name with an empty last name.
func splitFullName(name string) (firstName, lastName string) {
	firstName, lastName, _ = strings.Cut(name, " ")
	return firstName, lastName
}

// OIDCCallbackHandler completes the OIDC login flow.
//
// It validates the "state" query parameter against Redis (one-time use),
// exchanges "code" for tokens, verifies the ID token, then finds or creates
// the local user linked to the token's subject. New users are created with
// role "job_seeker" in the tenant whose slug is "merchants-of-hope"; existing
// users have their email and name refreshed from the ID token claims.
//
// On success it responds 200 with a 24-hour HS256 JWT, the user, and
// "method": "oidc". It responds 400 for a missing or unknown code/state and
// 500 for any failure in the exchange, verification, or database steps.
func OIDCCallbackHandler(c *gin.Context) {
	// Get authorization code and state from query params.
	code := c.Query("code")
	state := c.Query("state")
	if code == "" || state == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Missing code or state parameter"})
		return
	}

	// Verify the state parameter was issued by OIDCLoginHandler.
	ctx := context.Background()
	stateKey := fmt.Sprintf("oidc_state:%s", state)
	storedState, err := redisClient.Get(ctx, stateKey).Result()
	if err != nil || storedState != "valid" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid or expired state parameter"})
		return
	}
	// Remove the state (one-time use). Removal stays best-effort, but a
	// failure is logged because the state then remains replayable until it
	// expires.
	if err := redisClient.Del(ctx, stateKey).Err(); err != nil {
		log.Printf("Failed to delete OIDC state: %v", err)
	}

	// Exchange code for token.
	oauth2Token, err := oauth2Config.Exchange(ctx, code)
	if err != nil {
		log.Printf("Failed to exchange code for token: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to exchange code for token"})
		return
	}

	// Extract ID token.
	rawIDToken, ok := oauth2Token.Extra("id_token").(string)
	if !ok {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "No id_token in token response"})
		return
	}

	// Verify ID token.
	idToken, err := verifier.Verify(ctx, rawIDToken)
	if err != nil {
		log.Printf("Failed to verify ID token: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to verify ID token"})
		return
	}

	// Extract claims.
	var claims struct {
		Email    string `json:"email"`
		Name     string `json:"name"`
		Subject  string `json:"sub"`
		Verified bool   `json:"email_verified"`
	}
	if err := idToken.Claims(&claims); err != nil {
		log.Printf("Failed to parse ID token claims: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to parse ID token claims"})
		return
	}
	firstName, lastName := splitFullName(claims.Name)

	// Check if a user exists with this OIDC identity.
	var oidcIdentity models.OIDCIdentity
	result := db.DB.Where("provider_name = ? AND provider_subject = ?", "oidc", claims.Subject).First(&oidcIdentity)

	var user models.User
	if result.Error != nil {
		// NOTE(review): every lookup error, not only "record not found", is
		// treated as "no such identity" and leads to user creation — confirm
		// this is intended.
		user = models.User{
			Email:         claims.Email,
			FirstName:     firstName,
			LastName:      lastName,
			EmailVerified: claims.Verified,
			Role:          "job_seeker", // Default role for new OIDC users
		}

		// Create the user in the default tenant (MerchantsOfHope).
		var defaultTenant models.Tenant
		if err := db.DB.Where("slug = ?", "merchants-of-hope").First(&defaultTenant).Error; err != nil {
			log.Printf("Failed to get default tenant: %v", err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get default tenant"})
			return
		}
		user.TenantID = defaultTenant.ID
		if err := db.DB.Create(&user).Error; err != nil {
			log.Printf("Failed to create user: %v", err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"})
			return
		}

		// Link the new user to the OIDC subject.
		oidcIdentity = models.OIDCIdentity{
			UserID:          user.ID,
			ProviderName:    "oidc",
			ProviderSubject: claims.Subject,
		}
		if err := db.DB.Create(&oidcIdentity).Error; err != nil {
			log.Printf("Failed to create OIDC identity: %v", err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create OIDC identity"})
			return
		}
	} else {
		// Load the linked user with an explicit "id = ?" condition, as
		// JWTAuthMiddleware does. Passing a non-integer key as a bare inline
		// condition is not a reliable primary-key lookup.
		if err := db.DB.First(&user, "id = ?", oidcIdentity.UserID).Error; err != nil {
			log.Printf("Failed to find user for OIDC identity: %v", err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to find user for OIDC identity"})
			return
		}

		// Refresh email and name from the claims, saving only on change.
		updateNeeded := false
		if user.Email != claims.Email {
			user.Email = claims.Email
			updateNeeded = true
		}
		if user.FirstName != firstName {
			user.FirstName = firstName
			updateNeeded = true
		}
		if user.LastName != lastName {
			user.LastName = lastName
			updateNeeded = true
		}
		if updateNeeded {
			if err := db.DB.Save(&user).Error; err != nil {
				log.Printf("Failed to update user: %v", err)
				c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update user"})
				return
			}
		}
	}

	// Generate JWT token for the user.
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
		"user_id": user.ID.String(),
		"email":   user.Email,
		"role":    user.Role,
		"exp":     time.Now().Add(time.Hour * 24).Unix(), // Token expires in 24 hours
	})
	tokenString, err := token.SignedString([]byte(cfg.JWTSecret))
	if err != nil {
		log.Printf("Failed to generate JWT token: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate JWT token"})
		return
	}

	// Return token to client.
	c.JSON(http.StatusOK, gin.H{
		"token":  tokenString,
		"user":   user,
		"method": "oidc",
	})
}
// SocialLoginHandler starts the login flow for the social provider named by
// the ":provider" route parameter.
//
// For "google" and "github" it stores a one-time state value in Redis for
// five minutes and redirects the client (307) to the provider's authorization
// endpoint. Any other provider is answered with 400 "Unsupported provider".
// This includes "facebook", whose endpoint is listed but for which neither
// this handler nor getUserInfoFromProvider implements a flow. It answers 500
// when the state cannot be stored.
func SocialLoginHandler(c *gin.Context) {
	provider := c.Param("provider")

	// Validate provider.
	supportedProviders := map[string]string{
		"google":   "https://accounts.google.com/o/oauth2/v2/auth",
		"facebook": "https://www.facebook.com/v17.0/dialog/oauth",
		"github":   "https://github.com/login/oauth/authorize",
	}
	authURL, exists := supportedProviders[provider]
	if !exists {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Unsupported provider"})
		return
	}

	// Per-provider request settings. This is resolved before anything is
	// written to Redis so an unsupported provider leaves no state behind.
	var (
		callbackURL string
		scopes      []string
	)
	switch provider {
	case "google":
		callbackURL = "http://localhost:17000/api/v1/auth/social/callback/google"
		scopes = []string{"openid", "profile", "email"}
	case "github":
		callbackURL = "http://localhost:17000/api/v1/auth/social/callback/github"
		scopes = []string{"user:email"}
	default:
		// Without this branch the handler would redirect to an empty URL.
		c.JSON(http.StatusBadRequest, gin.H{"error": "Unsupported provider"})
		return
	}

	state := generateRandomState()

	// Store state in Redis for validation after callback.
	ctx := context.Background()
	err := redisClient.SetEx(ctx, fmt.Sprintf("social_state:%s:%s", provider, state), "valid", 5*time.Minute).Err()
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Internal server error"})
		return
	}

	// AuthCodeURL produces a properly escaped query string (scopes joined by
	// spaces, response_type=code) instead of hand-formatting one.
	// NOTE(review): both providers use cfg.OIDCClientID as their client ID,
	// as the original code did — confirm that is the intended credential.
	providerConfig := oauth2.Config{
		ClientID:    cfg.OIDCClientID,
		RedirectURL: callbackURL,
		Scopes:      scopes,
		Endpoint:    oauth2.Endpoint{AuthURL: authURL},
	}
	c.Redirect(http.StatusTemporaryRedirect, providerConfig.AuthCodeURL(state))
}
// SocialCallbackHandler completes the login flow for the social provider
// named by the ":provider" route parameter.
//
// It validates the "state" query parameter against Redis (one-time use),
// obtains the user's profile through getUserInfoFromProvider, then finds or
// creates the local user linked to that provider account. New users are
// created with role "job_seeker" in the tenant whose slug is
// "merchants-of-hope"; existing users have their email, name and stored
// provider tokens refreshed.
//
// On success it responds 200 with a 24-hour HS256 JWT, the user,
// "method": "social" and the provider name. It responds 400 for a missing or
// unknown code/state and 500 for any provider or database failure.
func SocialCallbackHandler(c *gin.Context) {
	provider := c.Param("provider")
	code := c.Query("code")
	state := c.Query("state")
	if code == "" || state == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Missing code or state parameter"})
		return
	}

	// Verify the state parameter was issued by SocialLoginHandler.
	ctx := context.Background()
	stateKey := fmt.Sprintf("social_state:%s:%s", provider, state)
	storedState, err := redisClient.Get(ctx, stateKey).Result()
	if err != nil || storedState != "valid" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid or expired state parameter"})
		return
	}
	// Remove the state (one-time use). Removal stays best-effort, but a
	// failure is logged because the state then remains replayable until it
	// expires.
	if err := redisClient.Del(ctx, stateKey).Err(); err != nil {
		log.Printf("Failed to delete social login state: %v", err)
	}

	// Get user info from social provider.
	userInfo, err := getUserInfoFromProvider(provider, code)
	if err != nil {
		log.Printf("Failed to get user info from %s: %v", provider, err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Failed to get user info from %s", provider)})
		return
	}

	// Check if a user exists with this social identity.
	var socialIdentity models.SocialIdentity
	result := db.DB.Where("provider_name = ? AND provider_user_id = ?", provider, userInfo.ProviderUserID).First(&socialIdentity)

	var user models.User
	if result.Error != nil {
		// NOTE(review): every lookup error, not only "record not found", is
		// treated as "no such identity" and leads to user creation — confirm
		// this is intended.
		user = models.User{
			Email:         userInfo.Email,
			FirstName:     userInfo.FirstName,
			LastName:      userInfo.LastName,
			EmailVerified: userInfo.EmailVerified,
			Role:          "job_seeker", // Default role for new social users
		}

		// Create the user in the default tenant (MerchantsOfHope).
		var defaultTenant models.Tenant
		if err := db.DB.Where("slug = ?", "merchants-of-hope").First(&defaultTenant).Error; err != nil {
			log.Printf("Failed to get default tenant: %v", err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get default tenant"})
			return
		}
		user.TenantID = defaultTenant.ID
		if err := db.DB.Create(&user).Error; err != nil {
			log.Printf("Failed to create user: %v", err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"})
			return
		}

		// Link the new user to the provider account.
		socialIdentity = models.SocialIdentity{
			UserID:         user.ID,
			ProviderName:   provider,
			ProviderUserID: userInfo.ProviderUserID,
			AccessToken:    userInfo.AccessToken,
			RefreshToken:   userInfo.RefreshToken,
			ExpiresAt:      userInfo.ExpiresAt,
			ProfileData:    userInfo.ProfileData,
		}
		if err := db.DB.Create(&socialIdentity).Error; err != nil {
			log.Printf("Failed to create social identity: %v", err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create social identity"})
			return
		}
	} else {
		// Load the linked user with an explicit "id = ?" condition, as
		// JWTAuthMiddleware does. Passing a non-integer key as a bare inline
		// condition is not a reliable primary-key lookup.
		if err := db.DB.First(&user, "id = ?", socialIdentity.UserID).Error; err != nil {
			log.Printf("Failed to find user for social identity: %v", err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to find user for social identity"})
			return
		}

		// Refresh email and name from the provider, saving only on change.
		updateNeeded := false
		if user.Email != userInfo.Email {
			user.Email = userInfo.Email
			updateNeeded = true
		}
		if user.FirstName != userInfo.FirstName {
			user.FirstName = userInfo.FirstName
			updateNeeded = true
		}
		if user.LastName != userInfo.LastName {
			user.LastName = userInfo.LastName
			updateNeeded = true
		}
		if updateNeeded {
			if err := db.DB.Save(&user).Error; err != nil {
				log.Printf("Failed to update user: %v", err)
				c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update user"})
				return
			}
		}

		// Always store the latest provider tokens and profile snapshot.
		socialIdentity.AccessToken = userInfo.AccessToken
		socialIdentity.RefreshToken = userInfo.RefreshToken
		socialIdentity.ExpiresAt = userInfo.ExpiresAt
		socialIdentity.ProfileData = userInfo.ProfileData
		if err := db.DB.Save(&socialIdentity).Error; err != nil {
			log.Printf("Failed to update social identity: %v", err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update social identity"})
			return
		}
	}

	// Generate JWT token for the user.
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
		"user_id": user.ID.String(),
		"email":   user.Email,
		"role":    user.Role,
		"exp":     time.Now().Add(time.Hour * 24).Unix(), // Token expires in 24 hours
	})
	tokenString, err := token.SignedString([]byte(cfg.JWTSecret))
	if err != nil {
		log.Printf("Failed to generate JWT token: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate JWT token"})
		return
	}

	// Return token to client.
	c.JSON(http.StatusOK, gin.H{
		"token":    tokenString,
		"user":     user,
		"method":   "social",
		"provider": provider,
	})
}
// SocialUserInfo is the provider-independent profile returned by
// getUserInfoFromProvider and consumed by SocialCallbackHandler, which copies
// it into models.User and models.SocialIdentity.
type SocialUserInfo struct {
// ProviderUserID is the account identifier at the provider; together with
// the provider name it is the lookup key for an existing social identity.
ProviderUserID string
// Email, FirstName and LastName are copied onto the local user.
Email string
FirstName string
LastName string
// EmailVerified is stored on newly created users only.
EmailVerified bool
// AccessToken and RefreshToken are persisted on the social identity.
AccessToken string
RefreshToken string
// ExpiresAt is persisted alongside the tokens; it may be nil.
// NOTE(review): presumably the access token's expiry — confirm.
ExpiresAt *time.Time
// ProfileData is persisted as-is on the social identity.
// NOTE(review): the mock providers fill it with a JSON document — confirm
// the intended format.
ProfileData string
}
// getUserInfoFromProvider returns the profile of the user who authorized the
// login at the given social provider.
//
// This is a mock: the code argument is ignored and fixed sample data is
// returned for "google" and "github". A real implementation would exchange
// the code for an access token at the provider's token endpoint, call the
// provider's user-info endpoint with it, and map the response onto
// SocialUserInfo. Any other provider yields an "unsupported provider" error.
func getUserInfoFromProvider(provider, code string) (*SocialUserInfo, error) {
	if provider == "google" {
		// Placeholder for the Google flow (token exchange against
		// https://oauth2.googleapis.com/token, then a profile request to
		// https://www.googleapis.com/oauth2/v2/userinfo).
		googleUser := SocialUserInfo{
			ProviderUserID: "google_123456789",
			Email:          "googleuser@example.com",
			FirstName:      "Google",
			LastName:       "User",
			EmailVerified:  true,
			AccessToken:    "mock_access_token",
			RefreshToken:   "mock_refresh_token",
			ExpiresAt:      new(time.Time), // non-nil pointer to the zero time
			ProfileData:    `{"sub": "123456789", "name": "Google User", "email": "googleuser@example.com"}`,
		}
		return &googleUser, nil
	}

	if provider == "github" {
		// Placeholder for the GitHub flow; no refresh token or expiry.
		githubUser := SocialUserInfo{
			ProviderUserID: "github_987654321",
			Email:          "githubuser@example.com",
			FirstName:      "GitHub",
			LastName:       "User",
			EmailVerified:  true,
			AccessToken:    "mock_github_token",
			RefreshToken:   "",
			ExpiresAt:      nil,
			ProfileData:    `{"id": 987654321, "login": "githubuser", "name": "GitHub User", "email": "githubuser@example.com"}`,
		}
		return &githubUser, nil
	}

	return nil, fmt.Errorf("unsupported provider: %s", provider)
}
// generateRandomState returns a fresh, unpredictable value for the OAuth2 /
// OIDC "state" parameter: 16 bytes from crypto/rand, hex-encoded to a
// 32-character lowercase string.
//
// The state must be unguessable for the callback's state check to protect
// against cross-site request forgery, so the function panics rather than
// return a predictable value if the system random source fails.
func generateRandomState() string {
	b := make([]byte, 16)
	if _, err := rand.Read(b); err != nil {
		panic("middleware: reading random bytes for OAuth state: " + err.Error())
	}
	return hex.EncodeToString(b)
}