the beginning of the idiots
This commit is contained in:
702
qwen/go/middleware/auth.go
Normal file
702
qwen/go/middleware/auth.go
Normal file
@@ -0,0 +1,702 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"mohportal/config"
|
||||
"mohportal/db"
|
||||
"mohportal/models"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/endpoints"
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
var (
|
||||
redisClient *redis.Client
|
||||
cfg *config.Config
|
||||
verifier *oidc.IDTokenVerifier
|
||||
oauth2Config *oauth2.Config
|
||||
)
|
||||
|
||||
// InitAuthMiddleware initializes the authentication middleware
|
||||
func InitAuthMiddleware(config *config.Config) {
|
||||
cfg = config
|
||||
|
||||
// Initialize Redis client
|
||||
redisClient = redis.NewClient(&redis.Options{
|
||||
Addr: cfg.RedisURL,
|
||||
})
|
||||
|
||||
// Initialize OIDC verifier
|
||||
provider, err := oidc.NewProvider(context.Background(), cfg.OIDCIssuer)
|
||||
if err != nil {
|
||||
log.Fatal("Failed to initialize OIDC provider:", err)
|
||||
}
|
||||
verifier = provider.Verifier(&oidc.Config{ClientID: cfg.OIDCClientID})
|
||||
|
||||
// Initialize OAuth2 config
|
||||
oauth2Config = &oauth2.Config{
|
||||
ClientID: cfg.OIDCClientID,
|
||||
ClientSecret: cfg.OIDCClientSecret,
|
||||
Endpoint: provider.Endpoint(),
|
||||
RedirectURL: "http://localhost:17000/api/v1/auth/callback",
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
}
|
||||
}
|
||||
|
||||
// JWTAuthMiddleware validates JWT tokens
|
||||
func JWTAuthMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization header required"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if tokenString == authHeader {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Bearer token required"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Parse and validate the token
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return []byte(cfg.JWTSecret), nil
|
||||
})
|
||||
|
||||
if err != nil || !token.Valid {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Extract claims
|
||||
if claims, ok := token.Claims.(jwt.MapClaims); ok {
|
||||
// Extract user ID from claims
|
||||
if userIDStr, ok := claims["user_id"].(string); ok {
|
||||
userID, err := uuid.Parse(userIDStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user ID in token"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Check if token is still valid in Redis (for logout functionality)
|
||||
tokenKey := fmt.Sprintf("blacklist:%s", tokenString)
|
||||
if val, err := redisClient.Get(context.Background(), tokenKey).Result(); err == nil && val == "true" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Token has been revoked"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Store user ID in context for use in handlers
|
||||
c.Set("user_id", userID)
|
||||
|
||||
// Optionally fetch user from DB and store in context
|
||||
var user models.User
|
||||
if err := db.DB.First(&user, "id = ?", userID).Error; err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "User not found"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Set("user", user)
|
||||
} else {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "User ID not found in token"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token claims"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// TenantAuthMiddleware ensures the user belongs to the correct tenant
|
||||
func TenantAuthMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Get user from context (set by JWTAuthMiddleware)
|
||||
user, exists := c.Get("user")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Type assertion to get user data
|
||||
userData, ok := user.(models.User)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Error getting user data"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Check if user's tenant matches the request context
|
||||
// In a real implementation, tenant could come from subdomain, header, or URL parameter
|
||||
// For now, we'll allow access if user is active and belongs to a valid tenant
|
||||
if !userData.IsActive || userData.TenantID == uuid.Nil {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "User does not belong to a valid tenant"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RoleAuthMiddleware checks if the user has the required role(s)
|
||||
func RoleAuthMiddleware(allowedRoles ...string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
user, exists := c.Get("user")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
userData, ok := user.(models.User)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Error getting user data"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Check if user has one of the allowed roles
|
||||
roleValid := false
|
||||
for _, allowedRole := range allowedRoles {
|
||||
if userData.Role == allowedRole {
|
||||
roleValid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !roleValid {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": "Insufficient permissions",
|
||||
"role": userData.Role,
|
||||
"required": allowedRoles,
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// LogoutHandler invalidates the JWT token by adding it to Redis blacklist
|
||||
func LogoutHandler(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Authorization header required"})
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if tokenString == authHeader {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Bearer token required"})
|
||||
return
|
||||
}
|
||||
|
||||
// Parse token to extract claims for expiration
|
||||
token, _ := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return []byte(cfg.JWTSecret), nil
|
||||
})
|
||||
|
||||
// Get token expiration time
|
||||
var expirationTime time.Time
|
||||
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
|
||||
if exp, ok := claims["exp"].(float64); ok {
|
||||
expirationTime = time.Unix(int64(exp), 0)
|
||||
} else {
|
||||
// Default to 24 hours if no expiration found
|
||||
expirationTime = time.Now().Add(24 * time.Hour)
|
||||
}
|
||||
} else {
|
||||
// Default to 24 hours if token is invalid
|
||||
expirationTime = time.Now().Add(24 * time.Hour)
|
||||
}
|
||||
|
||||
// Add token to Redis blacklist
|
||||
tokenKey := fmt.Sprintf("blacklist:%s", tokenString)
|
||||
ctx := context.Background()
|
||||
duration := time.Until(expirationTime)
|
||||
if duration > 0 {
|
||||
err := redisClient.SetEX(ctx, tokenKey, "true", duration).Err()
|
||||
if err != nil {
|
||||
log.Printf("Error adding token to blacklist: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Logout failed"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Successfully logged out"})
|
||||
}
|
||||
|
||||
// OIDCLoginHandler initiates OIDC login flow
|
||||
func OIDCLoginHandler(c *gin.Context) {
|
||||
state := generateRandomState()
|
||||
authURL := oauth2Config.AuthCodeURL(state, oauth2.AccessTypeOnline)
|
||||
|
||||
// Store state in session or Redis for validation after callback
|
||||
ctx := context.Background()
|
||||
err := redisClient.SetEX(ctx, fmt.Sprintf("oidc_state:%s", state), "valid", 5*time.Minute).Err()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Internal server error"})
|
||||
return
|
||||
}
|
||||
|
||||
c.Redirect(http.StatusTemporaryRedirect, authURL)
|
||||
}
|
||||
|
||||
// OIDCCallbackHandler handles OIDC callback
|
||||
func OIDCCallbackHandler(c *gin.Context) {
|
||||
// Get authorization code and state from query params
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
|
||||
if code == "" || state == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing code or state parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
// Verify state parameter
|
||||
ctx := context.Background()
|
||||
storedState, err := redisClient.Get(ctx, fmt.Sprintf("oidc_state:%s", state)).Result()
|
||||
if err != nil || storedState != "valid" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid or expired state parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
// Remove the state from Redis (one-time use)
|
||||
redisClient.Del(ctx, fmt.Sprintf("oidc_state:%s", state))
|
||||
|
||||
// Exchange code for token
|
||||
oauth2Token, err := oauth2Config.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
log.Printf("Failed to exchange code for token: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to exchange code for token"})
|
||||
return
|
||||
}
|
||||
|
||||
// Extract ID token
|
||||
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "No id_token in token response"})
|
||||
return
|
||||
}
|
||||
|
||||
// Verify ID token
|
||||
idToken, err := verifier.Verify(ctx, rawIDToken)
|
||||
if err != nil {
|
||||
log.Printf("Failed to verify ID token: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to verify ID token"})
|
||||
return
|
||||
}
|
||||
|
||||
// Extract claims
|
||||
var claims struct {
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Subject string `json:"sub"`
|
||||
Verified bool `json:"email_verified"`
|
||||
}
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
log.Printf("Failed to parse ID token claims: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to parse ID token claims"})
|
||||
return
|
||||
}
|
||||
|
||||
// Check if user exists with this OIDC identity
|
||||
var oidcIdentity models.OIDCIdentity
|
||||
result := db.DB.Where("provider_name = ? AND provider_subject = ?", "oidc", claims.Subject).First(&oidcIdentity)
|
||||
|
||||
var user models.User
|
||||
if result.Error != nil {
|
||||
// User doesn't exist, create new user
|
||||
parts := strings.Split(claims.Name, " ")
|
||||
firstName := claims.Name
|
||||
lastName := ""
|
||||
if len(parts) > 1 {
|
||||
firstName = parts[0]
|
||||
lastName = strings.Join(parts[1:], " ")
|
||||
}
|
||||
|
||||
user = models.User{
|
||||
Email: claims.Email,
|
||||
FirstName: firstName,
|
||||
LastName: lastName,
|
||||
EmailVerified: claims.Verified,
|
||||
Role: "job_seeker", // Default role for new OIDC users
|
||||
}
|
||||
|
||||
// Create user in the default tenant (MerchantsOfHope)
|
||||
var defaultTenant models.Tenant
|
||||
if err := db.DB.Where("slug = ?", "merchants-of-hope").First(&defaultTenant).Error; err != nil {
|
||||
log.Printf("Failed to get default tenant: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get default tenant"})
|
||||
return
|
||||
}
|
||||
user.TenantID = defaultTenant.ID
|
||||
|
||||
if err := db.DB.Create(&user).Error; err != nil {
|
||||
log.Printf("Failed to create user: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"})
|
||||
return
|
||||
}
|
||||
|
||||
// Create OIDC identity
|
||||
oidcIdentity = models.OIDCIdentity{
|
||||
UserID: user.ID,
|
||||
ProviderName: "oidc",
|
||||
ProviderSubject: claims.Subject,
|
||||
}
|
||||
if err := db.DB.Create(&oidcIdentity).Error; err != nil {
|
||||
log.Printf("Failed to create OIDC identity: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create OIDC identity"})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// User exists, update user info if needed
|
||||
if err := db.DB.First(&user, oidcIdentity.UserID).Error; err != nil {
|
||||
log.Printf("Failed to find user for OIDC identity: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to find user for OIDC identity"})
|
||||
return
|
||||
}
|
||||
|
||||
// Update user info if it has changed
|
||||
updateNeeded := false
|
||||
if user.Email != claims.Email {
|
||||
user.Email = claims.Email
|
||||
updateNeeded = true
|
||||
}
|
||||
|
||||
parts := strings.Split(claims.Name, " ")
|
||||
firstName := claims.Name
|
||||
lastName := ""
|
||||
if len(parts) > 1 {
|
||||
firstName = parts[0]
|
||||
lastName = strings.Join(parts[1:], " ")
|
||||
}
|
||||
|
||||
if user.FirstName != firstName {
|
||||
user.FirstName = firstName
|
||||
updateNeeded = true
|
||||
}
|
||||
|
||||
if user.LastName != lastName {
|
||||
user.LastName = lastName
|
||||
updateNeeded = true
|
||||
}
|
||||
|
||||
if updateNeeded {
|
||||
if err := db.DB.Save(&user).Error; err != nil {
|
||||
log.Printf("Failed to update user: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update user"})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Generate JWT token for the user
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"user_id": user.ID.String(),
|
||||
"email": user.Email,
|
||||
"role": user.Role,
|
||||
"exp": time.Now().Add(time.Hour * 24).Unix(), // Token expires in 24 hours
|
||||
})
|
||||
|
||||
tokenString, err := token.SignedString([]byte(cfg.JWTSecret))
|
||||
if err != nil {
|
||||
log.Printf("Failed to generate JWT token: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate JWT token"})
|
||||
return
|
||||
}
|
||||
|
||||
// Return token to client
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"token": tokenString,
|
||||
"user": user,
|
||||
"method": "oidc",
|
||||
})
|
||||
}
|
||||
|
||||
// SocialLoginHandler initiates social media login flow (for providers like Google, Facebook, etc.)
|
||||
func SocialLoginHandler(c *gin.Context) {
|
||||
provider := c.Param("provider")
|
||||
|
||||
// Validate provider
|
||||
supportedProviders := map[string]string{
|
||||
"google": "https://accounts.google.com/o/oauth2/v2/auth",
|
||||
"facebook": "https://www.facebook.com/v17.0/dialog/oauth",
|
||||
"github": "https://github.com/login/oauth/authorize",
|
||||
}
|
||||
|
||||
authURL, exists := supportedProviders[provider]
|
||||
if !exists {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Unsupported provider"})
|
||||
return
|
||||
}
|
||||
|
||||
state := generateRandomState()
|
||||
|
||||
// Store state in session or Redis for validation after callback
|
||||
ctx := context.Background()
|
||||
err := redisClient.SetEX(ctx, fmt.Sprintf("social_state:%s:%s", provider, state), "valid", 5*time.Minute).Err()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Internal server error"})
|
||||
return
|
||||
}
|
||||
|
||||
// Construct the auth URL based on the provider
|
||||
var redirectURL string
|
||||
switch provider {
|
||||
case "google":
|
||||
redirectURL = fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&response_type=code&scope=openid profile email&state=%s",
|
||||
authURL, cfg.OIDCClientID, "http://localhost:17000/api/v1/auth/social/callback/google", state)
|
||||
case "github":
|
||||
redirectURL = fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&scope=user:email&state=%s",
|
||||
authURL, cfg.OIDCClientID, "http://localhost:17000/api/v1/auth/social/callback/github", state)
|
||||
}
|
||||
|
||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
||||
}
|
||||
|
||||
// SocialCallbackHandler handles social media login callback
|
||||
func SocialCallbackHandler(c *gin.Context) {
|
||||
provider := c.Param("provider")
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
|
||||
if code == "" || state == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing code or state parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
// Verify state parameter
|
||||
ctx := context.Background()
|
||||
storedState, err := redisClient.Get(ctx, fmt.Sprintf("social_state:%s:%s", provider, state)).Result()
|
||||
if err != nil || storedState != "valid" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid or expired state parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
// Remove the state from Redis (one-time use)
|
||||
redisClient.Del(ctx, fmt.Sprintf("social_state:%s:%s", provider, state))
|
||||
|
||||
// Get user info from social provider
|
||||
userInfo, err := getUserInfoFromProvider(provider, code)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get user info from %s: %v", provider, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Failed to get user info from %s", provider)})
|
||||
return
|
||||
}
|
||||
|
||||
// Check if user exists with this social identity
|
||||
var socialIdentity models.SocialIdentity
|
||||
result := db.DB.Where("provider_name = ? AND provider_user_id = ?", provider, userInfo.ProviderUserID).First(&socialIdentity)
|
||||
|
||||
var user models.User
|
||||
if result.Error != nil {
|
||||
// User doesn't exist, create new user
|
||||
user = models.User{
|
||||
Email: userInfo.Email,
|
||||
FirstName: userInfo.FirstName,
|
||||
LastName: userInfo.LastName,
|
||||
EmailVerified: userInfo.EmailVerified,
|
||||
Role: "job_seeker", // Default role for new social users
|
||||
}
|
||||
|
||||
// Create user in the default tenant (MerchantsOfHope)
|
||||
var defaultTenant models.Tenant
|
||||
if err := db.DB.Where("slug = ?", "merchants-of-hope").First(&defaultTenant).Error; err != nil {
|
||||
log.Printf("Failed to get default tenant: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get default tenant"})
|
||||
return
|
||||
}
|
||||
user.TenantID = defaultTenant.ID
|
||||
|
||||
if err := db.DB.Create(&user).Error; err != nil {
|
||||
log.Printf("Failed to create user: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"})
|
||||
return
|
||||
}
|
||||
|
||||
// Create social identity
|
||||
socialIdentity = models.SocialIdentity{
|
||||
UserID: user.ID,
|
||||
ProviderName: provider,
|
||||
ProviderUserID: userInfo.ProviderUserID,
|
||||
AccessToken: userInfo.AccessToken,
|
||||
RefreshToken: userInfo.RefreshToken,
|
||||
ExpiresAt: userInfo.ExpiresAt,
|
||||
ProfileData: userInfo.ProfileData,
|
||||
}
|
||||
if err := db.DB.Create(&socialIdentity).Error; err != nil {
|
||||
log.Printf("Failed to create social identity: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create social identity"})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// User exists, update user info if needed
|
||||
if err := db.DB.First(&user, socialIdentity.UserID).Error; err != nil {
|
||||
log.Printf("Failed to find user for social identity: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to find user for social identity"})
|
||||
return
|
||||
}
|
||||
|
||||
// Update user info if it has changed
|
||||
updateNeeded := false
|
||||
if user.Email != userInfo.Email {
|
||||
user.Email = userInfo.Email
|
||||
updateNeeded = true
|
||||
}
|
||||
|
||||
if user.FirstName != userInfo.FirstName {
|
||||
user.FirstName = userInfo.FirstName
|
||||
updateNeeded = true
|
||||
}
|
||||
|
||||
if user.LastName != userInfo.LastName {
|
||||
user.LastName = userInfo.LastName
|
||||
updateNeeded = true
|
||||
}
|
||||
|
||||
if updateNeeded {
|
||||
if err := db.DB.Save(&user).Error; err != nil {
|
||||
log.Printf("Failed to update user: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update user"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Update social identity
|
||||
socialIdentity.AccessToken = userInfo.AccessToken
|
||||
socialIdentity.RefreshToken = userInfo.RefreshToken
|
||||
socialIdentity.ExpiresAt = userInfo.ExpiresAt
|
||||
socialIdentity.ProfileData = userInfo.ProfileData
|
||||
if err := db.DB.Save(&socialIdentity).Error; err != nil {
|
||||
log.Printf("Failed to update social identity: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update social identity"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Generate JWT token for the user
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"user_id": user.ID.String(),
|
||||
"email": user.Email,
|
||||
"role": user.Role,
|
||||
"exp": time.Now().Add(time.Hour * 24).Unix(), // Token expires in 24 hours
|
||||
})
|
||||
|
||||
tokenString, err := token.SignedString([]byte(cfg.JWTSecret))
|
||||
if err != nil {
|
||||
log.Printf("Failed to generate JWT token: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate JWT token"})
|
||||
return
|
||||
}
|
||||
|
||||
// Return token to client
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"token": tokenString,
|
||||
"user": user,
|
||||
"method": "social",
|
||||
"provider": provider,
|
||||
})
|
||||
}
|
||||
|
||||
// SocialUserInfo represents the user information returned from social providers
|
||||
type SocialUserInfo struct {
|
||||
ProviderUserID string
|
||||
Email string
|
||||
FirstName string
|
||||
LastName string
|
||||
EmailVerified bool
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
ExpiresAt *time.Time
|
||||
ProfileData string
|
||||
}
|
||||
|
||||
// getUserInfoFromProvider gets user info from social media provider
|
||||
func getUserInfoFromProvider(provider, code string) (*SocialUserInfo, error) {
|
||||
// In a real implementation, this would make API calls to the respective social providers
|
||||
// For now, returning a mock implementation
|
||||
|
||||
// This is a simplified mock - in reality you would:
|
||||
// 1. Exchange the code for an access token
|
||||
// 2. Use the access token to get user profile information
|
||||
// 3. Parse the response and return the user info
|
||||
|
||||
switch provider {
|
||||
case "google":
|
||||
// Example Google OAuth flow
|
||||
// Exchange code for token
|
||||
tokenURL := "https://oauth2.googleapis.com/token"
|
||||
// ... perform token exchange ...
|
||||
|
||||
// Get user info
|
||||
// userInfoURL := "https://www.googleapis.com/oauth2/v2/userinfo"
|
||||
// ... perform user info request ...
|
||||
|
||||
// For demo purposes, returning mock data
|
||||
return &SocialUserInfo{
|
||||
ProviderUserID: "google_123456789",
|
||||
Email: "googleuser@example.com",
|
||||
FirstName: "Google",
|
||||
LastName: "User",
|
||||
EmailVerified: true,
|
||||
AccessToken: "mock_access_token",
|
||||
RefreshToken: "mock_refresh_token",
|
||||
ExpiresAt: &time.Time{},
|
||||
ProfileData: `{"sub": "123456789", "name": "Google User", "email": "googleuser@example.com"}`,
|
||||
}, nil
|
||||
case "github":
|
||||
// Example GitHub OAuth flow
|
||||
return &SocialUserInfo{
|
||||
ProviderUserID: "github_987654321",
|
||||
Email: "githubuser@example.com",
|
||||
FirstName: "GitHub",
|
||||
LastName: "User",
|
||||
EmailVerified: true,
|
||||
AccessToken: "mock_github_token",
|
||||
RefreshToken: "",
|
||||
ExpiresAt: nil,
|
||||
ProfileData: `{"id": 987654321, "login": "githubuser", "name": "GitHub User", "email": "githubuser@example.com"}`,
|
||||
}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported provider: %s", provider)
|
||||
}
|
||||
}
|
||||
|
||||
// generateRandomState generates a random state parameter for OIDC
|
||||
func generateRandomState() string {
|
||||
b := make([]byte, 32)
|
||||
for i := range b {
|
||||
b[i] = byte('a' + (i % 26))
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
Reference in New Issue
Block a user