package api

import (
	"crypto/rand"
	"crypto/subtle"
	"encoding/json"
	"log"
	"net/http"
	"time"

	"github.com/dgrijalva/jwt-go"
	"golang.org/x/crypto/argon2"
)

// creds is the JSON request body decoded by SignIn.
type creds struct {
	// Name is the member name looked up in the database.
	Name string `json:"name"`
	// Password is the plaintext password; SignIn hashes it with the member's
	// stored salt and compares the result against the stored hash.
	Password string `json:"password"`
}

// SignIn authenticates a member by name and password and, on success,
// responds with a freshly signed token together with the member record.
//
// The request body is decoded as a JSON object {"name": ..., "password": ...}.
// Responses:
//   - 400 Bad Request if the body cannot be decoded, no member has that name,
//     or the password does not match;
//   - 500 Internal Server Error if the token or the response cannot be built;
//   - 200 OK with {"token": ..., "member": ...} otherwise.
func (a *api) SignIn(w http.ResponseWriter, req *http.Request) {
	// The body is untrusted and its password is fed to Argon2, so cap its size
	// rather than decoding an arbitrarily large payload.
	req.Body = http.MaxBytesReader(w, req.Body, 1<<20)

	var c creds
	if err := json.NewDecoder(req.Body).Decode(&c); err != nil {
		// A malformed body is a client error, not a server error.
		log.Printf("Can't decode auth credentials: %v", err)
		w.WriteHeader(http.StatusBadRequest)
		return
	}

	// NOTE(review): an unknown name returns before any hashing, so response
	// time may reveal which names exist. Consider hashing against a dummy salt here.
	var member Member
	if err := a.db.Where("name = ?", c.Name).First(&member).Error; err != nil {
		// %q escapes control characters in the user-supplied name (log injection).
		log.Printf("Can't locate user %q: %v", c.Name, err)
		w.WriteHeader(http.StatusBadRequest)
		return
	}

	hash := hashPass(c.Password, member.Salt)
	if subtle.ConstantTimeCompare(hash, member.PassHash) == 0 {
		log.Printf("Invalid pass for %q", c.Name)
		w.WriteHeader(http.StatusBadRequest)
		return
	}

	token, err := a.newToken(member.Num, member.Role)
	if err != nil {
		log.Printf("Can't create a token: %v", err)
		w.WriteHeader(http.StatusInternalServerError)
		return
	}

	// NOTE(review): the whole Member record is serialized. Confirm Member's JSON
	// tags exclude PassHash and Salt (e.g. `json:"-"`), otherwise the password
	// hash and salt are sent to the client.
	body, err := json.Marshal(map[string]interface{}{
		"token":  token,
		"member": member,
	})
	if err != nil {
		log.Printf("Can't encode member: %v", err)
		w.WriteHeader(http.StatusInternalServerError)
		return
	}

	// Commit headers only after every fallible step has succeeded: once
	// WriteHeader has run, a later WriteHeader(500) is ignored.
	log.Printf("Logged in as %q", c.Name)
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	// json.Encoder terminated the body with a newline; keep that format.
	if _, err := w.Write(append(body, '\n')); err != nil {
		log.Printf("Can't write response: %v", err)
	}
}

// GetToken issues a new token to a caller that already holds a valid one.
//
// The token is read from the "x-authentication" header. Its "num" and "role"
// claims are copied into a newly signed token with a fresh expiry.
// Responses:
//   - 401 Unauthorized if the presented token is invalid or lacks a numeric
//     "num" or string "role" claim;
//   - 500 Internal Server Error if the new token or the response cannot be built;
//   - 200 OK with {"token": ...} otherwise.
func (a *api) GetToken(w http.ResponseWriter, req *http.Request) {
	ok, claims := a.validateToken(req.Header.Get("x-authentication"))
	if !ok {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	// Numeric claims are asserted as float64, the type encoding/json uses for
	// numbers decoded into interface{}.
	num, ok := claims["num"].(float64)
	if !ok {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	role, ok := claims["role"].(string)
	if !ok {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}

	token, err := a.newToken(int(num), role)
	if err != nil {
		log.Printf("Can't create a token: %v", err)
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	body, err := json.Marshal(map[string]interface{}{
		"token": token,
	})
	if err != nil {
		log.Printf("Can't encode token: %v", err)
		w.WriteHeader(http.StatusInternalServerError)
		return
	}

	// Commit headers only after every fallible step has succeeded: once
	// WriteHeader has run, a later WriteHeader(500) is ignored.
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	// json.Encoder terminated the body with a newline; keep that format.
	if _, err := w.Write(append(body, '\n')); err != nil {
		log.Printf("Can't write response: %v", err)
	}
}

// auth wraps fn so it only runs for requests whose "x-authentication" header
// holds a token accepted by validateToken; all other requests get
// 401 Unauthorized and fn is not called.
func (a *api) auth(fn func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
	return func(w http.ResponseWriter, req *http.Request) {
		valid, _ := a.validateToken(req.Header.Get("x-authentication"))
		if valid {
			fn(w, req)
			return
		}
		w.WriteHeader(http.StatusUnauthorized)
	}
}

// authNum wraps fn so it only runs for requests carrying a valid
// "x-authentication" token with a numeric "num" claim, which is passed to fn
// as an int. Any other request gets 401 Unauthorized.
func (a *api) authNum(fn func(int, http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
	return func(w http.ResponseWriter, req *http.Request) {
		valid, claims := a.validateToken(req.Header.Get("x-authentication"))
		if valid {
			// float64 is the type encoding/json uses for numbers in interface{}.
			if memberNum, isNum := claims["num"].(float64); isNum {
				fn(int(memberNum), w, req)
				return
			}
		}
		w.WriteHeader(http.StatusUnauthorized)
	}
}

// authAdmin wraps fn so it only runs for requests carrying a valid
// "x-authentication" token whose "role" claim is exactly "admin". Any other
// request, including a missing or non-string role, gets 401 Unauthorized.
func (a *api) authAdmin(fn func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
	return func(w http.ResponseWriter, req *http.Request) {
		valid, claims := a.validateToken(req.Header.Get("x-authentication"))
		if valid {
			// A failed assertion yields "", which never equals "admin".
			if role, _ := claims["role"].(string); role == "admin" {
				fn(w, req)
				return
			}
		}
		w.WriteHeader(http.StatusUnauthorized)
	}
}

// authNumRole wraps fn so it only runs for requests carrying a valid
// "x-authentication" token with both a numeric "num" claim and a string
// "role" claim. These are passed to fn as (int, string). Any other request
// gets 401 Unauthorized.
func (a *api) authNumRole(fn func(int, string, http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
	return func(w http.ResponseWriter, req *http.Request) {
		valid, claims := a.validateToken(req.Header.Get("x-authentication"))
		if !valid {
			w.WriteHeader(http.StatusUnauthorized)
			return
		}
		memberNum, hasNum := claims["num"].(float64)
		memberRole, hasRole := claims["role"].(string)
		if !hasNum || !hasRole {
			w.WriteHeader(http.StatusUnauthorized)
			return
		}
		fn(int(memberNum), memberRole, w, req)
	}
}

// newToken returns an HS256-signed JWT, signed with a.signKey. It carries the
// member number ("num"), the role ("role") and an expiry ("exp", Unix seconds)
// 24 hours from now.
func (a *api) newToken(num int, role string) (string, error) {
	expiresAt := time.Now().Add(24 * time.Hour)
	claims := jwt.MapClaims{
		"num":  num,
		"role": role,
		"exp":  expiresAt.Unix(),
	}
	signed := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	return signed.SignedString(a.signKey)
}

// validateToken parses token and verifies it against a.signKey.
//
// It returns true and the token's claims only when all of these hold:
//   - the token is signed with HS256, the algorithm newToken uses;
//   - the signature verifies;
//   - an "exp" claim is present and lies in the future.
//
// Otherwise it returns false and nil claims.
//
// NOTE(review): github.com/dgrijalva/jwt-go is unmaintained (CVE-2020-26160);
// consider migrating to github.com/golang-jwt/jwt.
func (a *api) validateToken(token string) (bool, jwt.MapClaims) {
	t, err := jwt.Parse(token, func(parsed *jwt.Token) (interface{}, error) {
		// Pin the algorithm instead of trusting the token's own "alg" header.
		// Letting the token choose its verification method is the root of
		// JWT algorithm-confusion attacks.
		if parsed.Method.Alg() != jwt.SigningMethodHS256.Alg() {
			return nil, jwt.ErrSignatureInvalid
		}
		return a.signKey, nil
	})
	if err != nil || !t.Valid {
		return false, nil
	}
	claims, ok := t.Claims.(jwt.MapClaims)
	if !ok {
		return false, nil
	}
	// Require "exp" explicitly: MapClaims.Valid only checks it when present,
	// so a token without an expiry would otherwise never expire.
	exp, ok := claims["exp"].(float64)
	if !ok || !time.Unix(int64(exp), 0).After(time.Now()) {
		return false, nil
	}
	return true, claims
}

// newHashPass generates a fresh 32-byte random salt with crypto/rand and
// returns hashPass(password, salt) together with that salt. SignIn later
// recomputes hashPass with the stored salt and compares it to the stored hash,
// so both values must be persisted. On error, hash and salt are both nil.
func newHashPass(password string) (hash []byte, salt []byte, err error) {
	salt = make([]byte, 32)
	if _, err = rand.Read(salt); err != nil {
		// Don't hand back a partially filled salt alongside the error.
		return nil, nil, err
	}
	return hashPass(password, salt), salt, nil
}

// hashPass derives a 32-byte Argon2id key from password and salt.
//
// The cost parameters are not stored with the hash. Changing any of them makes
// every existing stored hash fail verification in SignIn.
func hashPass(password string, salt []byte) []byte {
	// Named so they don't shadow the imported time package.
	const (
		passes    uint32 = 1
		memoryKiB uint32 = 64 * 1024 // argon2 takes memory in KiB: 64 MiB
		lanes     uint8  = 2
		hashLen   uint32 = 32
	)
	secret := []byte(password)
	return argon2.IDKey(secret, salt, passes, memoryKiB, lanes, hashLen)
}