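// Package middleware provides Gin authentication middleware (JWT, tenant and
// role checks) and login/logout handlers for OIDC and social providers.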
package middleware
|
|
|
|
import (
	"context"
	"crypto/rand"
	"encoding/hex"
	"fmt"
	"log"
	"net/http"
	"strings"
	"time"

	"mohportal/config"
	"mohportal/db"
	"mohportal/models"

	"github.com/coreos/go-oidc/v3/oidc"
	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt/v5"
	"github.com/google/uuid"
	"github.com/redis/go-redis/v9"
	"golang.org/x/oauth2"
)
|
|
|
|
var (
|
|
redisClient *redis.Client
|
|
cfg *config.Config
|
|
verifier *oidc.IDTokenVerifier
|
|
oauth2Config *oauth2.Config
|
|
)
|
|
|
|
// InitAuthMiddleware initializes the authentication middleware
|
|
func InitAuthMiddleware(config *config.Config) {
|
|
cfg = config
|
|
|
|
// Initialize Redis client
|
|
redisClient = redis.NewClient(&redis.Options{
|
|
Addr: cfg.RedisURL,
|
|
})
|
|
|
|
// Initialize OIDC verifier
|
|
provider, err := oidc.NewProvider(context.Background(), cfg.OIDCIssuer)
|
|
if err != nil {
|
|
log.Fatal("Failed to initialize OIDC provider:", err)
|
|
}
|
|
verifier = provider.Verifier(&oidc.Config{ClientID: cfg.OIDCClientID})
|
|
|
|
// Initialize OAuth2 config
|
|
oauth2Config = &oauth2.Config{
|
|
ClientID: cfg.OIDCClientID,
|
|
ClientSecret: cfg.OIDCClientSecret,
|
|
Endpoint: provider.Endpoint(),
|
|
RedirectURL: "http://localhost:17000/api/v1/auth/callback",
|
|
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
|
}
|
|
}
|
|
|
|
// JWTAuthMiddleware validates JWT tokens
|
|
func JWTAuthMiddleware() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
authHeader := c.GetHeader("Authorization")
|
|
if authHeader == "" {
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization header required"})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
|
|
if tokenString == authHeader {
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "Bearer token required"})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
// Parse and validate the token
|
|
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
}
|
|
return []byte(cfg.JWTSecret), nil
|
|
})
|
|
|
|
if err != nil || !token.Valid {
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
// Extract claims
|
|
if claims, ok := token.Claims.(jwt.MapClaims); ok {
|
|
// Extract user ID from claims
|
|
if userIDStr, ok := claims["user_id"].(string); ok {
|
|
userID, err := uuid.Parse(userIDStr)
|
|
if err != nil {
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user ID in token"})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
// Check if token is still valid in Redis (for logout functionality)
|
|
tokenKey := fmt.Sprintf("blacklist:%s", tokenString)
|
|
if val, err := redisClient.Get(context.Background(), tokenKey).Result(); err == nil && val == "true" {
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "Token has been revoked"})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
// Store user ID in context for use in handlers
|
|
c.Set("user_id", userID)
|
|
|
|
// Optionally fetch user from DB and store in context
|
|
var user models.User
|
|
if err := db.DB.First(&user, "id = ?", userID).Error; err != nil {
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "User not found"})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
c.Set("user", user)
|
|
} else {
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "User ID not found in token"})
|
|
c.Abort()
|
|
return
|
|
}
|
|
} else {
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token claims"})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// TenantAuthMiddleware ensures the user belongs to the correct tenant
|
|
func TenantAuthMiddleware() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
// Get user from context (set by JWTAuthMiddleware)
|
|
user, exists := c.Get("user")
|
|
if !exists {
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
// Type assertion to get user data
|
|
userData, ok := user.(models.User)
|
|
if !ok {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Error getting user data"})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
// Check if user's tenant matches the request context
|
|
// In a real implementation, tenant could come from subdomain, header, or URL parameter
|
|
// For now, we'll allow access if user is active and belongs to a valid tenant
|
|
if !userData.IsActive || userData.TenantID == uuid.Nil {
|
|
c.JSON(http.StatusForbidden, gin.H{"error": "User does not belong to a valid tenant"})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// RoleAuthMiddleware checks if the user has the required role(s)
|
|
func RoleAuthMiddleware(allowedRoles ...string) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
user, exists := c.Get("user")
|
|
if !exists {
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
userData, ok := user.(models.User)
|
|
if !ok {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Error getting user data"})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
// Check if user has one of the allowed roles
|
|
roleValid := false
|
|
for _, allowedRole := range allowedRoles {
|
|
if userData.Role == allowedRole {
|
|
roleValid = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !roleValid {
|
|
c.JSON(http.StatusForbidden, gin.H{
|
|
"error": "Insufficient permissions",
|
|
"role": userData.Role,
|
|
"required": allowedRoles,
|
|
})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// LogoutHandler invalidates the JWT token by adding it to Redis blacklist
|
|
func LogoutHandler(c *gin.Context) {
|
|
authHeader := c.GetHeader("Authorization")
|
|
if authHeader == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "Authorization header required"})
|
|
return
|
|
}
|
|
|
|
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
|
|
if tokenString == authHeader {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "Bearer token required"})
|
|
return
|
|
}
|
|
|
|
// Parse token to extract claims for expiration
|
|
token, _ := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
}
|
|
return []byte(cfg.JWTSecret), nil
|
|
})
|
|
|
|
// Get token expiration time
|
|
var expirationTime time.Time
|
|
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
|
|
if exp, ok := claims["exp"].(float64); ok {
|
|
expirationTime = time.Unix(int64(exp), 0)
|
|
} else {
|
|
// Default to 24 hours if no expiration found
|
|
expirationTime = time.Now().Add(24 * time.Hour)
|
|
}
|
|
} else {
|
|
// Default to 24 hours if token is invalid
|
|
expirationTime = time.Now().Add(24 * time.Hour)
|
|
}
|
|
|
|
// Add token to Redis blacklist
|
|
tokenKey := fmt.Sprintf("blacklist:%s", tokenString)
|
|
ctx := context.Background()
|
|
duration := time.Until(expirationTime)
|
|
if duration > 0 {
|
|
err := redisClient.SetEx(ctx, tokenKey, "true", duration).Err()
|
|
if err != nil {
|
|
log.Printf("Error adding token to blacklist: %v", err)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Logout failed"})
|
|
return
|
|
}
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{"message": "Successfully logged out"})
|
|
}
|
|
|
|
// OIDCLoginHandler initiates OIDC login flow
|
|
func OIDCLoginHandler(c *gin.Context) {
|
|
state := generateRandomState()
|
|
authURL := oauth2Config.AuthCodeURL(state, oauth2.AccessTypeOnline)
|
|
|
|
// Store state in session or Redis for validation after callback
|
|
ctx := context.Background()
|
|
err := redisClient.SetEx(ctx, fmt.Sprintf("oidc_state:%s", state), "valid", 5*time.Minute).Err()
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Internal server error"})
|
|
return
|
|
}
|
|
|
|
c.Redirect(http.StatusTemporaryRedirect, authURL)
|
|
}
|
|
|
|
// OIDCCallbackHandler handles OIDC callback
|
|
func OIDCCallbackHandler(c *gin.Context) {
|
|
// Get authorization code and state from query params
|
|
code := c.Query("code")
|
|
state := c.Query("state")
|
|
|
|
if code == "" || state == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing code or state parameter"})
|
|
return
|
|
}
|
|
|
|
// Verify state parameter
|
|
ctx := context.Background()
|
|
storedState, err := redisClient.Get(ctx, fmt.Sprintf("oidc_state:%s", state)).Result()
|
|
if err != nil || storedState != "valid" {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid or expired state parameter"})
|
|
return
|
|
}
|
|
|
|
// Remove the state from Redis (one-time use)
|
|
redisClient.Del(ctx, fmt.Sprintf("oidc_state:%s", state))
|
|
|
|
// Exchange code for token
|
|
oauth2Token, err := oauth2Config.Exchange(ctx, code)
|
|
if err != nil {
|
|
log.Printf("Failed to exchange code for token: %v", err)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to exchange code for token"})
|
|
return
|
|
}
|
|
|
|
// Extract ID token
|
|
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
|
|
if !ok {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "No id_token in token response"})
|
|
return
|
|
}
|
|
|
|
// Verify ID token
|
|
idToken, err := verifier.Verify(ctx, rawIDToken)
|
|
if err != nil {
|
|
log.Printf("Failed to verify ID token: %v", err)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to verify ID token"})
|
|
return
|
|
}
|
|
|
|
// Extract claims
|
|
var claims struct {
|
|
Email string `json:"email"`
|
|
Name string `json:"name"`
|
|
Subject string `json:"sub"`
|
|
Verified bool `json:"email_verified"`
|
|
}
|
|
if err := idToken.Claims(&claims); err != nil {
|
|
log.Printf("Failed to parse ID token claims: %v", err)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to parse ID token claims"})
|
|
return
|
|
}
|
|
|
|
// Check if user exists with this OIDC identity
|
|
var oidcIdentity models.OIDCIdentity
|
|
result := db.DB.Where("provider_name = ? AND provider_subject = ?", "oidc", claims.Subject).First(&oidcIdentity)
|
|
|
|
var user models.User
|
|
if result.Error != nil {
|
|
// User doesn't exist, create new user
|
|
parts := strings.Split(claims.Name, " ")
|
|
firstName := claims.Name
|
|
lastName := ""
|
|
if len(parts) > 1 {
|
|
firstName = parts[0]
|
|
lastName = strings.Join(parts[1:], " ")
|
|
}
|
|
|
|
user = models.User{
|
|
Email: claims.Email,
|
|
FirstName: firstName,
|
|
LastName: lastName,
|
|
EmailVerified: claims.Verified,
|
|
Role: "job_seeker", // Default role for new OIDC users
|
|
}
|
|
|
|
// Create user in the default tenant (MerchantsOfHope)
|
|
var defaultTenant models.Tenant
|
|
if err := db.DB.Where("slug = ?", "merchants-of-hope").First(&defaultTenant).Error; err != nil {
|
|
log.Printf("Failed to get default tenant: %v", err)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get default tenant"})
|
|
return
|
|
}
|
|
user.TenantID = defaultTenant.ID
|
|
|
|
if err := db.DB.Create(&user).Error; err != nil {
|
|
log.Printf("Failed to create user: %v", err)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"})
|
|
return
|
|
}
|
|
|
|
// Create OIDC identity
|
|
oidcIdentity = models.OIDCIdentity{
|
|
UserID: user.ID,
|
|
ProviderName: "oidc",
|
|
ProviderSubject: claims.Subject,
|
|
}
|
|
if err := db.DB.Create(&oidcIdentity).Error; err != nil {
|
|
log.Printf("Failed to create OIDC identity: %v", err)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create OIDC identity"})
|
|
return
|
|
}
|
|
} else {
|
|
// User exists, update user info if needed
|
|
if err := db.DB.First(&user, oidcIdentity.UserID).Error; err != nil {
|
|
log.Printf("Failed to find user for OIDC identity: %v", err)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to find user for OIDC identity"})
|
|
return
|
|
}
|
|
|
|
// Update user info if it has changed
|
|
updateNeeded := false
|
|
if user.Email != claims.Email {
|
|
user.Email = claims.Email
|
|
updateNeeded = true
|
|
}
|
|
|
|
parts := strings.Split(claims.Name, " ")
|
|
firstName := claims.Name
|
|
lastName := ""
|
|
if len(parts) > 1 {
|
|
firstName = parts[0]
|
|
lastName = strings.Join(parts[1:], " ")
|
|
}
|
|
|
|
if user.FirstName != firstName {
|
|
user.FirstName = firstName
|
|
updateNeeded = true
|
|
}
|
|
|
|
if user.LastName != lastName {
|
|
user.LastName = lastName
|
|
updateNeeded = true
|
|
}
|
|
|
|
if updateNeeded {
|
|
if err := db.DB.Save(&user).Error; err != nil {
|
|
log.Printf("Failed to update user: %v", err)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update user"})
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// Generate JWT token for the user
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
|
"user_id": user.ID.String(),
|
|
"email": user.Email,
|
|
"role": user.Role,
|
|
"exp": time.Now().Add(time.Hour * 24).Unix(), // Token expires in 24 hours
|
|
})
|
|
|
|
tokenString, err := token.SignedString([]byte(cfg.JWTSecret))
|
|
if err != nil {
|
|
log.Printf("Failed to generate JWT token: %v", err)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate JWT token"})
|
|
return
|
|
}
|
|
|
|
// Return token to client
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"token": tokenString,
|
|
"user": user,
|
|
"method": "oidc",
|
|
})
|
|
}
|
|
|
|
// SocialLoginHandler initiates social media login flow (for providers like Google, Facebook, etc.)
|
|
func SocialLoginHandler(c *gin.Context) {
|
|
provider := c.Param("provider")
|
|
|
|
// Validate provider
|
|
supportedProviders := map[string]string{
|
|
"google": "https://accounts.google.com/o/oauth2/v2/auth",
|
|
"facebook": "https://www.facebook.com/v17.0/dialog/oauth",
|
|
"github": "https://github.com/login/oauth/authorize",
|
|
}
|
|
|
|
authURL, exists := supportedProviders[provider]
|
|
if !exists {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "Unsupported provider"})
|
|
return
|
|
}
|
|
|
|
state := generateRandomState()
|
|
|
|
// Store state in session or Redis for validation after callback
|
|
ctx := context.Background()
|
|
err := redisClient.SetEx(ctx, fmt.Sprintf("social_state:%s:%s", provider, state), "valid", 5*time.Minute).Err()
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Internal server error"})
|
|
return
|
|
}
|
|
|
|
// Construct the auth URL based on the provider
|
|
var redirectURL string
|
|
switch provider {
|
|
case "google":
|
|
redirectURL = fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&response_type=code&scope=openid profile email&state=%s",
|
|
authURL, cfg.OIDCClientID, "http://localhost:17000/api/v1/auth/social/callback/google", state)
|
|
case "github":
|
|
redirectURL = fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&scope=user:email&state=%s",
|
|
authURL, cfg.OIDCClientID, "http://localhost:17000/api/v1/auth/social/callback/github", state)
|
|
}
|
|
|
|
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
|
}
|
|
|
|
// SocialCallbackHandler handles social media login callback
|
|
func SocialCallbackHandler(c *gin.Context) {
|
|
provider := c.Param("provider")
|
|
code := c.Query("code")
|
|
state := c.Query("state")
|
|
|
|
if code == "" || state == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing code or state parameter"})
|
|
return
|
|
}
|
|
|
|
// Verify state parameter
|
|
ctx := context.Background()
|
|
storedState, err := redisClient.Get(ctx, fmt.Sprintf("social_state:%s:%s", provider, state)).Result()
|
|
if err != nil || storedState != "valid" {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid or expired state parameter"})
|
|
return
|
|
}
|
|
|
|
// Remove the state from Redis (one-time use)
|
|
redisClient.Del(ctx, fmt.Sprintf("social_state:%s:%s", provider, state))
|
|
|
|
// Get user info from social provider
|
|
userInfo, err := getUserInfoFromProvider(provider, code)
|
|
if err != nil {
|
|
log.Printf("Failed to get user info from %s: %v", provider, err)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Failed to get user info from %s", provider)})
|
|
return
|
|
}
|
|
|
|
// Check if user exists with this social identity
|
|
var socialIdentity models.SocialIdentity
|
|
result := db.DB.Where("provider_name = ? AND provider_user_id = ?", provider, userInfo.ProviderUserID).First(&socialIdentity)
|
|
|
|
var user models.User
|
|
if result.Error != nil {
|
|
// User doesn't exist, create new user
|
|
user = models.User{
|
|
Email: userInfo.Email,
|
|
FirstName: userInfo.FirstName,
|
|
LastName: userInfo.LastName,
|
|
EmailVerified: userInfo.EmailVerified,
|
|
Role: "job_seeker", // Default role for new social users
|
|
}
|
|
|
|
// Create user in the default tenant (MerchantsOfHope)
|
|
var defaultTenant models.Tenant
|
|
if err := db.DB.Where("slug = ?", "merchants-of-hope").First(&defaultTenant).Error; err != nil {
|
|
log.Printf("Failed to get default tenant: %v", err)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get default tenant"})
|
|
return
|
|
}
|
|
user.TenantID = defaultTenant.ID
|
|
|
|
if err := db.DB.Create(&user).Error; err != nil {
|
|
log.Printf("Failed to create user: %v", err)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"})
|
|
return
|
|
}
|
|
|
|
// Create social identity
|
|
socialIdentity = models.SocialIdentity{
|
|
UserID: user.ID,
|
|
ProviderName: provider,
|
|
ProviderUserID: userInfo.ProviderUserID,
|
|
AccessToken: userInfo.AccessToken,
|
|
RefreshToken: userInfo.RefreshToken,
|
|
ExpiresAt: userInfo.ExpiresAt,
|
|
ProfileData: userInfo.ProfileData,
|
|
}
|
|
if err := db.DB.Create(&socialIdentity).Error; err != nil {
|
|
log.Printf("Failed to create social identity: %v", err)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create social identity"})
|
|
return
|
|
}
|
|
} else {
|
|
// User exists, update user info if needed
|
|
if err := db.DB.First(&user, socialIdentity.UserID).Error; err != nil {
|
|
log.Printf("Failed to find user for social identity: %v", err)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to find user for social identity"})
|
|
return
|
|
}
|
|
|
|
// Update user info if it has changed
|
|
updateNeeded := false
|
|
if user.Email != userInfo.Email {
|
|
user.Email = userInfo.Email
|
|
updateNeeded = true
|
|
}
|
|
|
|
if user.FirstName != userInfo.FirstName {
|
|
user.FirstName = userInfo.FirstName
|
|
updateNeeded = true
|
|
}
|
|
|
|
if user.LastName != userInfo.LastName {
|
|
user.LastName = userInfo.LastName
|
|
updateNeeded = true
|
|
}
|
|
|
|
if updateNeeded {
|
|
if err := db.DB.Save(&user).Error; err != nil {
|
|
log.Printf("Failed to update user: %v", err)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update user"})
|
|
return
|
|
}
|
|
}
|
|
|
|
// Update social identity
|
|
socialIdentity.AccessToken = userInfo.AccessToken
|
|
socialIdentity.RefreshToken = userInfo.RefreshToken
|
|
socialIdentity.ExpiresAt = userInfo.ExpiresAt
|
|
socialIdentity.ProfileData = userInfo.ProfileData
|
|
if err := db.DB.Save(&socialIdentity).Error; err != nil {
|
|
log.Printf("Failed to update social identity: %v", err)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update social identity"})
|
|
return
|
|
}
|
|
}
|
|
|
|
// Generate JWT token for the user
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
|
"user_id": user.ID.String(),
|
|
"email": user.Email,
|
|
"role": user.Role,
|
|
"exp": time.Now().Add(time.Hour * 24).Unix(), // Token expires in 24 hours
|
|
})
|
|
|
|
tokenString, err := token.SignedString([]byte(cfg.JWTSecret))
|
|
if err != nil {
|
|
log.Printf("Failed to generate JWT token: %v", err)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate JWT token"})
|
|
return
|
|
}
|
|
|
|
// Return token to client
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"token": tokenString,
|
|
"user": user,
|
|
"method": "social",
|
|
"provider": provider,
|
|
})
|
|
}
|
|
|
|
// SocialUserInfo carries the profile and token data obtained from a social
// login provider. SocialCallbackHandler copies it into models.User and
// models.SocialIdentity.
type SocialUserInfo struct {
	// ProviderUserID identifies the user at the provider. Together with the
	// provider name, it keys the stored social identity.
	ProviderUserID string

	// Profile fields copied onto the local user record.
	Email, FirstName, LastName string
	EmailVerified              bool

	// Provider credentials. RefreshToken may be empty and ExpiresAt nil when
	// the provider issues none (as in the GitHub mock).
	AccessToken, RefreshToken string
	ExpiresAt                 *time.Time

	// ProfileData is the provider's raw profile payload. The mocks supply it
	// as a JSON string.
	ProfileData string
}
|
|
|
|
// getUserInfoFromProvider gets user info from social media provider
|
|
func getUserInfoFromProvider(provider, code string) (*SocialUserInfo, error) {
|
|
// In a real implementation, this would make API calls to the respective social providers
|
|
// For now, returning a mock implementation
|
|
|
|
// This is a simplified mock - in reality you would:
|
|
// 1. Exchange the code for an access token
|
|
// 2. Use the access token to get user profile information
|
|
// 3. Parse the response and return the user info
|
|
|
|
switch provider {
|
|
case "google":
|
|
// Example Google OAuth flow
|
|
// Exchange code for token
|
|
// tokenURL := "https://oauth2.googleapis.com/token"
|
|
// ... perform token exchange ...
|
|
|
|
// Get user info
|
|
// userInfoURL := "https://www.googleapis.com/oauth2/v2/userinfo"
|
|
// ... perform user info request ...
|
|
|
|
// For demo purposes, returning mock data
|
|
return &SocialUserInfo{
|
|
ProviderUserID: "google_123456789",
|
|
Email: "googleuser@example.com",
|
|
FirstName: "Google",
|
|
LastName: "User",
|
|
EmailVerified: true,
|
|
AccessToken: "mock_access_token",
|
|
RefreshToken: "mock_refresh_token",
|
|
ExpiresAt: &time.Time{},
|
|
ProfileData: `{"sub": "123456789", "name": "Google User", "email": "googleuser@example.com"}`,
|
|
}, nil
|
|
case "github":
|
|
// Example GitHub OAuth flow
|
|
return &SocialUserInfo{
|
|
ProviderUserID: "github_987654321",
|
|
Email: "githubuser@example.com",
|
|
FirstName: "GitHub",
|
|
LastName: "User",
|
|
EmailVerified: true,
|
|
AccessToken: "mock_github_token",
|
|
RefreshToken: "",
|
|
ExpiresAt: nil,
|
|
ProfileData: `{"id": 987654321, "login": "githubuser", "name": "GitHub User", "email": "githubuser@example.com"}`,
|
|
}, nil
|
|
default:
|
|
return nil, fmt.Errorf("unsupported provider: %s", provider)
|
|
}
|
|
}
|
|
|
|
// generateRandomState returns an unguessable, URL-safe value for the OAuth2 /
// OIDC "state" parameter. It reads 32 bytes from crypto/rand and hex-encodes
// them into 64 characters.
//
// The state must be unpredictable to provide CSRF protection (RFC 6749
// §10.12). It must also be unique per login, because the Redis keys that track
// pending logins are derived from it.
//
// It panics if the operating system's secure random source fails. Returning a
// guessable value would silently disable CSRF protection, and the signature
// has no error result.
func generateRandomState() string {
	buf := make([]byte, 32)
	if _, err := rand.Read(buf); err != nil {
		panic(fmt.Sprintf("generateRandomState: reading crypto/rand: %v", err))
	}
	return hex.EncodeToString(buf)
}