package middleware

import (
	"crypto/sha256"
	"crypto/subtle"
	"errors"

	"github.com/dave-gray101/v2keyauth"
	"github.com/gofiber/fiber/v2"
	"github.com/gofiber/fiber/v2/middleware/keyauth"
	"github.com/mudler/LocalAI/core/config"
)

// This file contains the configuration builder and the handler functions used with the fiber keyauth middleware.
// They currently depend on an upstream patch, and fiber v2 no longer accepts feature patches.
// Therefore `dave-gray101/v2keyauth` carries a v2 backport of the middleware until fiber v3 stabilizes and we migrate.

// GetKeyAuthConfig builds the v2keyauth middleware configuration for the given
// application settings.
//
// The API key is looked up from the Authorization, x-api-key and xi-api-key
// headers through v2keyauth.MultipleKeySourceLookup, using
// keyauth.ConfigDefault.AuthScheme as the scheme. Skipping, validation and
// error handling are delegated to the filter, validator and error-handler
// builders in this file. An error is returned only if the lookup cannot be
// constructed.
func GetKeyAuthConfig(applicationConfig *config.ApplicationConfig) (*v2keyauth.Config, error) {
	keySources := []string{
		"header:Authorization",
		"header:x-api-key",
		"header:xi-api-key",
	}
	lookup, err := v2keyauth.MultipleKeySourceLookup(keySources, keyauth.ConfigDefault.AuthScheme)
	if err != nil {
		return nil, err
	}

	cfg := v2keyauth.Config{
		CustomKeyLookup: lookup,
		Next:            getApiKeyRequiredFilterFunction(applicationConfig),
		Validator:       getApiKeyValidationFunction(applicationConfig),
		ErrorHandler:    getApiKeyErrorHandler(applicationConfig),
		AuthScheme:      "Bearer",
	}
	return &cfg, nil
}

// getApiKeyErrorHandler returns the ErrorHandler for the keyauth middleware.
//
// For v2keyauth.ErrMissingOrMalformedAPIKey:
//   - if no API keys are configured, the request continues via ctx.Next();
//   - otherwise the response is 403, with the error text as the body unless
//     applicationConfig.OpaqueErrors is set.
//
// Any other error becomes a bare 500 when OpaqueErrors is set, and is returned
// unchanged otherwise.
func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) fiber.ErrorHandler {
	return func(ctx *fiber.Ctx, err error) error {
		if !errors.Is(err, v2keyauth.ErrMissingOrMalformedAPIKey) {
			if applicationConfig.OpaqueErrors {
				return ctx.SendStatus(fiber.StatusInternalServerError)
			}
			return err
		}

		// With no keys configured, authentication is effectively off, so a
		// missing key is expected rather than an error.
		if len(applicationConfig.ApiKeys) == 0 {
			return ctx.Next()
		}
		if applicationConfig.OpaqueErrors {
			return ctx.SendStatus(fiber.StatusForbidden)
		}
		return ctx.Status(fiber.StatusForbidden).SendString(err.Error())
	}
}

// getApiKeyValidationFunction returns the Validator for the keyauth middleware.
//
// The returned function reads applicationConfig.ApiKeys on every call. The
// UseSubtleKeyComparison flag, by contrast, is read once when this builder runs.
// If no API keys are configured, every key is accepted. Otherwise the presented
// key must equal one of the configured keys. On a mismatch the validator
// returns (false, v2keyauth.ErrMissingOrMalformedAPIKey).
//
// With UseSubtleKeyComparison set, the comparison is made timing-safe:
//   - Both the presented key and each configured key are hashed with SHA-256
//     before subtle.ConstantTimeCompare. That function returns early when its
//     inputs differ in length, so comparing raw keys would leak the configured
//     key's length.
//   - Every configured key is compared even after a match, so timing does not
//     reveal which key matched.
func getApiKeyValidationFunction(applicationConfig *config.ApplicationConfig) func(*fiber.Ctx, string) (bool, error) {
	if applicationConfig.UseSubtleKeyComparison {
		return func(ctx *fiber.Ctx, apiKey string) (bool, error) {
			if len(applicationConfig.ApiKeys) == 0 {
				return true, nil // If no keys are setup, accept everything
			}
			presented := sha256.Sum256([]byte(apiKey))
			matched := 0
			for _, validKey := range applicationConfig.ApiKeys {
				expected := sha256.Sum256([]byte(validKey))
				// Accumulate instead of returning early so the loop always
				// runs over every configured key.
				matched |= subtle.ConstantTimeCompare(presented[:], expected[:])
			}
			if matched == 1 {
				return true, nil
			}
			return false, v2keyauth.ErrMissingOrMalformedAPIKey
		}
	}

	return func(ctx *fiber.Ctx, apiKey string) (bool, error) {
		if len(applicationConfig.ApiKeys) == 0 {
			return true, nil // If no keys are setup, accept everything
		}
		for _, validKey := range applicationConfig.ApiKeys {
			if apiKey == validKey {
				return true, nil
			}
		}
		return false, v2keyauth.ErrMissingOrMalformedAPIKey
	}
}

// getApiKeyRequiredFilterFunction returns the Next filter for the keyauth
// middleware. Assuming v2keyauth follows fiber's Next convention, a return
// value of true skips key checking for the request.
//
// DisableApiKeyRequirementForHttpGet is read once, when this builder runs:
//   - if it is false, the filter never skips;
//   - if it is true, a GET request is skipped when its path matches any regex in
//     applicationConfig.HttpGetExemptedEndpoints, which is re-read on every call.
func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig) func(*fiber.Ctx) bool {
	if !applicationConfig.DisableApiKeyRequirementForHttpGet {
		return func(c *fiber.Ctx) bool { return false }
	}

	return func(c *fiber.Ctx) bool {
		if c.Method() != "GET" {
			return false
		}
		path := c.Path()
		for _, exempt := range applicationConfig.HttpGetExemptedEndpoints {
			if exempt.MatchString(path) {
				return true
			}
		}
		return false
	}
}