package oauth2

import (
	"context"
	"crypto/sha256"
	"encoding/hex"
	"errors"
	"fmt"
	"log/slog"
	"net/http"
	"strconv"
	"time"

	"github.com/jkroepke/openvpn-auth-oauth2/internal/oauth2/idtoken"
	"github.com/jkroepke/openvpn-auth-oauth2/internal/oauth2/types"
	"github.com/jkroepke/openvpn-auth-oauth2/internal/state"
	"github.com/zitadel/logging"
	"github.com/zitadel/oidc/v3/pkg/client/rp"
)

// openvpnManagementClient is the set of operations the OAuth2 handlers use to
// report the outcome of an authentication attempt for an OpenVPN client.
// NOTE(review): presumably implemented by the OpenVPN management-interface
// connection; confirm against the concrete implementation.
type openvpnManagementClient interface {
	// AcceptClient admits the client identified by client. reAuth marks the
	// call as a re-authentication (the interactive OAuth2 callback in this
	// package passes false) and username is the name reported for the client.
	AcceptClient(ctx context.Context, logger *slog.Logger, client state.ClientIdentifier, reAuth bool, username string)
	// DenyClient rejects the client identified by client, with reason as the
	// human-readable rejection message.
	DenyClient(ctx context.Context, logger *slog.Logger, client state.ClientIdentifier, reason string)
}

// OAuth2Start returns a http.Handler that starts the OAuth2 authorization flow.
// It checks if the request has a valid state GET parameter generated by state.New.
// Optionally, it checks the HTTP client IP address against the VPN IP address.
// After the checks, the request is delegated to [rp.AuthURLHandler].
//
// When oauth2.nonce is enabled, a nonce derived from the client ID (or the
// session ID when oauth2.refresh.use-session-id is set and a session ID is
// present) is added to the authorization URL.
func (c Client) OAuth2Start() http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()

		// check if request has a state GET parameter generated state.New.
		sessionState := r.URL.Query().Get("state")
		if sessionState == "" {
			w.WriteHeader(http.StatusBadRequest)

			return
		}

		// decode the state GET parameter
		session, err := state.NewWithEncodedToken(sessionState, c.conf.HTTP.Secret.String())
		if err != nil {
			c.logger.LogAttrs(ctx, slog.LevelWarn, "invalid state: "+err.Error())
			w.WriteHeader(http.StatusBadRequest)

			return
		}

		logger := c.logger.With(
			slog.String("ip", fmt.Sprintf("%s:%s", session.IPAddr, session.IPPort)),
			slog.Uint64("cid", session.Client.CID),
			slog.Uint64("kid", session.Client.KID),
			slog.String("common_name", session.Client.CommonName),
		)

		if c.conf.HTTP.Check.IPAddr {
			if err := checkClientIPAddr(r, c.conf, session); err != nil {
				logger.LogAttrs(ctx, slog.LevelWarn, err.Error())

				// Errors other than ErrClientRejected are unexpected failures of
				// the check itself: deny with a generic reason and answer 500.
				if !errors.Is(err, ErrClientRejected) {
					c.openvpn.DenyClient(ctx, logger, session.Client, "client rejected")
					w.WriteHeader(http.StatusInternalServerError)

					return
				}

				c.openvpn.DenyClient(ctx, logger, session.Client, err.Error())
				w.WriteHeader(http.StatusForbidden)

				return
			}
		}

		logger.LogAttrs(ctx, slog.LevelInfo, "initialize authorization via oauth2")

		// c.authorizeParams is shared by every concurrent request. Capping the
		// capacity at the length forces the append below to allocate a fresh
		// backing array. Otherwise the per-request nonce could be written into
		// the shared array and race with, or leak into, other requests.
		authorizeParams := c.authorizeParams[:len(c.authorizeParams):len(c.authorizeParams)]

		if c.conf.OAuth2.Nonce {
			id := strconv.FormatUint(session.Client.CID, 10)
			if c.conf.OAuth2.Refresh.UseSessionID && session.Client.SessionID != "" {
				id = session.Client.SessionID
			}

			authorizeParams = append(authorizeParams, rp.WithURLParam("nonce", c.getNonce(id)))
		}

		rp.AuthURLHandler(func() string {
			return sessionState
		}, c.relyingParty, authorizeParams...).ServeHTTP(w, r)
	})
}

// OAuth2Callback returns a http.Handler that handles the OAuth2 callback.
//
// It decodes the state GET parameter, derives a request context bounded by a
// 30 second timeout and carrying a request-scoped logger (plus the expected
// nonce when oauth2.nonce is enabled), and delegates the authorization code
// exchange to [rp.CodeExchangeHandler]. When oauth2.userinfo is enabled the
// userinfo endpoint is queried as well ([rp.UserinfoCallback]). The outcome
// is handled by postCodeExchangeHandler.
func (c Client) OAuth2Callback() http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
		defer cancel()

		encryptedState := r.URL.Query().Get("state")
		if encryptedState == "" {
			c.writeHTTPError(ctx, w, c.logger, http.StatusBadRequest, "Bad Request", "state is empty")

			return
		}

		session, err := state.NewWithEncodedToken(encryptedState, c.conf.HTTP.Secret.String())
		if err != nil {
			c.writeHTTPError(ctx, w, c.logger, http.StatusBadRequest, "Invalid State", err.Error())

			return
		}

		logger := c.logger.With(
			slog.String("ip", fmt.Sprintf("%s:%s", session.IPAddr, session.IPPort)),
			slog.Uint64("cid", session.Client.CID),
			slog.Uint64("kid", session.Client.KID),
			slog.String("common_name", session.Client.CommonName),
			slog.String("session_id", session.Client.SessionID),
			slog.String("session_state", session.SessionState),
		)

		ctx = logging.ToContext(ctx, logger)

		// Key under which refresh data and the nonce are tracked: the client ID
		// by default, or the OpenVPN session ID when configured and available.
		clientID := strconv.FormatUint(session.Client.CID, 10)
		if c.conf.OAuth2.Refresh.UseSessionID && session.Client.SessionID != "" {
			clientID = session.Client.SessionID
		}

		if c.conf.OAuth2.Nonce {
			ctx = context.WithValue(ctx, types.CtxNonce{}, c.getNonce(clientID))
		}

		// Always attach the derived context to the request. If this happened
		// only in the nonce branch, the timeout and the request-scoped logger
		// would silently be dropped whenever oauth2.nonce is disabled.
		r = r.WithContext(ctx)

		codeExchangeHandler := rp.CodeExchangeCallback[*idtoken.Claims](func(
			w http.ResponseWriter, r *http.Request,
			tokens idtoken.IDToken, oauthState string, provider rp.RelyingParty,
		) {
			// Without userinfo lookup there is no userinfo payload to pass on.
			c.postCodeExchangeHandler(logger, session, clientID)(w, r, tokens, oauthState, provider, nil)
		})

		if c.conf.OAuth2.UserInfo {
			codeExchangeHandler = rp.UserinfoCallback(c.postCodeExchangeHandler(logger, session, clientID))
		}

		rp.CodeExchangeHandler(
			codeExchangeHandler,
			c.relyingParty,
		).ServeHTTP(w, r)
	})
}

// postCodeExchangeHandler returns the callback that finishes an OAuth2 login
// once the authorization code has been exchanged for tokens.
//
// The callback resolves the user from tokens and userInfo (userInfo is nil
// when the userinfo endpoint is not queried), validates that user against
// session, and then either:
//   - denies the OpenVPN client and renders an error page, or
//   - accepts the client, stores refresh data under clientID via
//     postCodeExchangeHandlerStoreRefreshToken, and renders the success page.
//
// The username reported to OpenVPN is the user's preferred username, falling
// back to the session's common name when that is empty.
func (c Client) postCodeExchangeHandler(
	logger *slog.Logger, session state.State, clientID string,
) rp.CodeExchangeUserinfoCallback[*idtoken.Claims, *types.UserInfo] {
	return func(
		w http.ResponseWriter, r *http.Request, tokens idtoken.IDToken, _ string,
		_ rp.RelyingParty, userInfo *types.UserInfo,
	) {
		ctx := r.Context()

		// Enrich a per-invocation logger instead of reassigning the captured
		// parameter. The returned callback is a reusable value, and mutating
		// the captured logger would accumulate attributes across invocations
		// and race if the callback ever ran concurrently.
		reqLogger := logger

		if tokens.IDTokenClaims != nil {
			reqLogger = reqLogger.With(
				slog.String("idtoken_subject", tokens.IDTokenClaims.Subject),
				slog.String("idtoken_email", tokens.IDTokenClaims.EMail),
				slog.String("idtoken_preferred_username", tokens.IDTokenClaims.PreferredUsername),
			)

			reqLogger.LogAttrs(ctx, slog.LevelDebug, "claims", slog.Any("claims", tokens.IDTokenClaims.Claims))
		}

		user, err := c.provider.GetUser(ctx, reqLogger, tokens, userInfo)
		if err != nil {
			c.openvpn.DenyClient(ctx, reqLogger, session.Client, "unable to fetch user data")
			c.writeHTTPError(ctx, w, reqLogger, http.StatusInternalServerError, "fetchUser", err.Error())

			return
		}

		reqLogger = reqLogger.With(
			slog.String("user_subject", user.Subject),
			slog.String("user_preferred_username", user.PreferredUsername),
		)

		if err = c.provider.CheckUser(ctx, session, user, tokens); err != nil {
			c.openvpn.DenyClient(ctx, reqLogger, session.Client, "client rejected")
			c.writeHTTPError(ctx, w, reqLogger, http.StatusForbidden, "user validation", err.Error())

			return
		}

		reqLogger.LogAttrs(ctx, slog.LevelInfo, "successful authorization via oauth2")

		username := user.PreferredUsername
		if username == "" {
			username = session.Client.CommonName
		}

		c.openvpn.AcceptClient(ctx, reqLogger, session.Client, false, username)
		c.postCodeExchangeHandlerStoreRefreshToken(ctx, reqLogger, session, clientID, tokens)
		c.writeHTTPSuccess(ctx, w, reqLogger)
	}
}

// postCodeExchangeHandlerStoreRefreshToken stores, under clientID, the data
// used for a later non-interactive re-authentication.
//
// It does nothing unless oauth2.refresh is enabled. If oauth2.refresh does not
// validate the user, types.EmptyToken is stored instead of a real refresh
// token. Otherwise the refresh token is obtained from tokens and stored.
// All failures are logged and otherwise ignored; nothing is returned.
func (c Client) postCodeExchangeHandlerStoreRefreshToken(
	ctx context.Context, logger *slog.Logger, session state.State, clientID string, tokens idtoken.IDToken,
) {
	if !c.conf.OAuth2.Refresh.Enabled {
		return
	}

	if !c.conf.OAuth2.Refresh.ValidateUser {
		if err := c.storage.Set(clientID, types.EmptyToken); err != nil {
			logger.LogAttrs(ctx, slog.LevelWarn, err.Error())
		} else {
			logger.LogAttrs(ctx, slog.LevelDebug, "empty token for non-interactive re-authentication stored")
		}

		return
	}

	refreshToken, err := c.provider.GetRefreshToken(tokens)
	if err != nil {
		logLevel := slog.LevelWarn

		// For sessions in the "Authenticated" or "AuthenticatedEmptyUser"
		// state, a missing refresh token is treated as non-critical.
		if errors.Is(err, ErrNoRefreshToken) &&
			(session.SessionState == "AuthenticatedEmptyUser" || session.SessionState == "Authenticated") {
			logLevel = slog.LevelDebug
		}

		logger.LogAttrs(ctx, logLevel, fmt.Errorf("oauth2.refresh is enabled, but %w", err).Error())

		// The problem has been reported at the appropriate level. Falling
		// through would emit an extra "refresh token is empty" warning and
		// defeat the debug-level downgrade above.
		return
	}

	if refreshToken == "" {
		logger.LogAttrs(ctx, slog.LevelWarn, "refresh token is empty")

		return
	}

	if err := c.storage.Set(clientID, refreshToken); err != nil {
		logger.LogAttrs(ctx, slog.LevelWarn, "unable to store refresh token",
			slog.Any("err", err),
		)

		return
	}

	logger.LogAttrs(ctx, slog.LevelDebug, "refresh token for non-interactive re-authentication stored")
}

// httpErrorHandler renders the access-denied page for a failed OAuth2 flow.
//
// If encryptedSession decodes to a valid state, the corresponding OpenVPN
// client is denied first, and the log entry carries the client's address and
// identifiers. If decoding fails, that failure is logged at debug level and
// only the error page is written.
func (c Client) httpErrorHandler(ctx context.Context, w http.ResponseWriter, httpStatus int, errorType, errorDesc, encryptedSession string) {
	session, decodeErr := state.NewWithEncodedToken(encryptedSession, c.conf.HTTP.Secret.String())
	if decodeErr != nil {
		c.logger.LogAttrs(ctx, slog.LevelDebug, "httpErrorHandler: "+decodeErr.Error())
		c.writeHTTPError(ctx, w, c.logger, httpStatus, errorType, errorDesc)

		return
	}

	sessionLogger := c.logger.With(
		slog.String("ip", fmt.Sprintf("%s:%s", session.IPAddr, session.IPPort)),
		slog.Uint64("cid", session.Client.CID),
		slog.Uint64("kid", session.Client.KID),
		slog.String("common_name", session.Client.CommonName),
	)

	c.openvpn.DenyClient(ctx, sessionLogger, session.Client, "client rejected")
	c.writeHTTPError(ctx, w, sessionLogger, httpStatus, errorType, errorDesc)
}

// writeHTTPError logs an error and renders the generic access-denied page.
//
// A 401 status is sent as 403. An error ID (hex-encoded SHA-256 of the current
// time) is logged together with errorType and errorDesc and shown on the page.
// The page itself only contains a generic message, so errorDesc is never
// exposed to the user.
func (c Client) writeHTTPError(ctx context.Context, w http.ResponseWriter, logger *slog.Logger, httpCode int, errorType, errorDesc string) {
	if httpCode == http.StatusUnauthorized {
		httpCode = http.StatusForbidden
	}

	sum := sha256.Sum256([]byte(time.Now().String()))
	errorID := hex.EncodeToString(sum[:])

	logger.LogAttrs(ctx, slog.LevelWarn, fmt.Sprintf("%s: %s", errorType, errorDesc), slog.String("error_id", errorID))
	w.WriteHeader(httpCode)

	err := c.conf.HTTP.Template.Execute(w, map[string]string{
		"title":   "Access denied",
		"message": "Please contact your administrator.",
		"errorID": errorID,
	})
	if err != nil {
		// The status line has already been committed above and cannot be
		// changed any more. A second WriteHeader call would only make net/http
		// log a "superfluous response.WriteHeader call" warning, so just log.
		logger.LogAttrs(ctx, slog.LevelError, fmt.Errorf("executing template: %w", err).Error())
	}
}

// writeHTTPSuccess renders the access-granted page with an empty error ID.
//
// If rendering fails, the error is logged and a 500 status is requested. Per
// net/http semantics, that status has no effect if the template already wrote
// part of the body.
func (c Client) writeHTTPSuccess(ctx context.Context, w http.ResponseWriter, logger *slog.Logger) {
	page := map[string]string{
		"title":   "Access granted",
		"message": "You can close this window now.",
		"errorID": "",
	}

	if renderErr := c.conf.HTTP.Template.Execute(w, page); renderErr != nil {
		wrapped := fmt.Errorf("executing template: %w", renderErr)
		logger.LogAttrs(ctx, slog.LevelError, "template error", slog.Any("err", wrapped))
		w.WriteHeader(http.StatusInternalServerError)
	}
}
