From a9ae96b5698d6e984cda94096ba49e52fe470179 Mon Sep 17 00:00:00 2001
From: meskio <meskio@sindominio.net>
Date: Fri, 18 Sep 2020 19:34:05 +0200
Subject: [PATCH] Add authentication in the API side

---
 api/api.go       |  18 ++++---
 api/api_test.go  |  22 ++++++++-
 api/auth.go      | 124 +++++++++++++++++++++++++++++++++++++++++++++++
 api/auth_test.go |  50 +++++++++++++++++++
 api/member.go    |  26 ++++++++--
 go.mod           |   2 +
 go.sum           |   5 ++
 main.go          |   7 +--
 8 files changed, 237 insertions(+), 17 deletions(-)
 create mode 100644 api/auth.go
 create mode 100644 api/auth_test.go

diff --git a/api/api.go b/api/api.go
index 4c83095..5191679 100644
--- a/api/api.go
+++ b/api/api.go
@@ -7,7 +7,8 @@ import (
 )
 
 type api struct {
-	db *gorm.DB
+	db      *gorm.DB
+	signKey []byte
 }
 
 func initDB(dbPath string) (*gorm.DB, error) {
@@ -20,17 +21,18 @@ func initDB(dbPath string) (*gorm.DB, error) {
 	return db, err
 }
 
-func Init(dbPath string, r *mux.Router) error {
+func Init(dbPath string, signKey string, r *mux.Router) error {
 	db, err := initDB(dbPath)
 	if err != nil {
 		return err
 	}
 
-	a := api{db}
-	r.HandleFunc("/member", a.ListMembers).Methods("GET")
-	r.HandleFunc("/member", a.AddMember).Methods("POST")
-	r.HandleFunc("/member/{num:[0-9]+}", a.GetMember).Methods("GET")
-	r.HandleFunc("/member/{num:[0-9]+}", a.UpdateMember).Methods("PUT")
-	r.HandleFunc("/member/{num:[0-9]+}", a.DeleteMember).Methods("DELETE")
+	a := api{db, []byte(signKey)}
+	r.HandleFunc("/signin", a.SignIn).Methods("POST")
+	r.HandleFunc("/member", a.auth(a.ListMembers)).Methods("GET")
+	r.HandleFunc("/member", a.auth(a.AddMember)).Methods("POST")
+	r.HandleFunc("/member/{num:[0-9]+}", a.auth(a.GetMember)).Methods("GET")
+	r.HandleFunc("/member/{num:[0-9]+}", a.auth(a.UpdateMember)).Methods("PUT")
+	r.HandleFunc("/member/{num:[0-9]+}", a.auth(a.DeleteMember)).Methods("DELETE")
 	return nil
 }
diff --git a/api/api_test.go b/api/api_test.go
index 06df2f7..83e4599 100644
--- a/api/api_test.go
+++ b/api/api_test.go
@@ -9,10 +9,16 @@ import (
 	"os"
 	"path"
 	"testing"
+	"time"
 
+	"github.com/dgrijalva/jwt-go"
 	"github.com/gorilla/mux"
 )
 
+const (
+	signKey = "secret"
+)
+
 func TestInit(t *testing.T) {
 	tapi := newTestAPI(t)
 	defer tapi.close()
@@ -29,6 +35,7 @@ type testAPI struct {
 	client   *http.Client
 	server   *httptest.Server
 	testPath string
+	token    string
 }
 
 func newTestAPI(t *testing.T) *testAPI {
@@ -39,13 +46,23 @@ func newTestAPI(t *testing.T) *testAPI {
 	dbPath := path.Join(testPath, "test.db")
 
 	r := mux.NewRouter()
-	err = Init(dbPath, r)
+	err = Init(dbPath, signKey, r)
 	if err != nil {
 		t.Fatal("Init error:", err)
 	}
 	server := httptest.NewServer(r)
 
-	return &testAPI{t, server.URL, &http.Client{}, server, testPath}
+	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
+		"num":  0,
+		"role": "admin",
+		"exp":  time.Now().Add(time.Hour * 24).Unix(),
+	})
+	tokenString, err := token.SignedString([]byte(signKey))
+	if err != nil {
+		t.Fatal("Can't generate token:", err)
+	}
+
+	return &testAPI{t, server.URL, &http.Client{}, server, testPath, tokenString}
 }
 
 func (ta *testAPI) do(method string, url string, body interface{}, respBody interface{}) *http.Response {
@@ -62,6 +79,7 @@ func (ta *testAPI) do(method string, url string, body interface{}, respBody inte
 	if err != nil {
 		ta.t.Fatal("Can't build request", method, url, err)
 	}
+	req.Header.Add("x-authentication", ta.token)
 	resp, err := ta.client.Do(req)
 	if err != nil {
 		ta.t.Fatal("HTTP query failed", method, url, err)
diff --git a/api/auth.go b/api/auth.go
new file mode 100644
index 0000000..11cdf58
--- /dev/null
+++ b/api/auth.go
@@ -0,0 +1,124 @@
+package api
+
+import (
+	"crypto/rand"
+	"crypto/subtle"
+	"encoding/json"
+	"log"
+	"net/http"
+	"time"
+
+	"github.com/dgrijalva/jwt-go"
+	"golang.org/x/crypto/argon2"
+)
+
+type creds struct {
+	Name     string `json:"name"`
+	Password string `json:"password"`
+}
+
+func (a *api) SignIn(w http.ResponseWriter, req *http.Request) {
+	var c creds
+	err := json.NewDecoder(req.Body).Decode(&c)
+	if err != nil {
+		log.Printf("Can't decode auth credentials: %v", err)
+		w.WriteHeader(http.StatusInternalServerError)
+		return
+	}
+	var member Member
+	err = a.db.Where("name = ?", c.Name).First(&member).Error
+	if err != nil {
+		log.Printf("Can't locate user %s: %v", c.Name, err)
+		w.WriteHeader(http.StatusBadRequest)
+		return
+	}
+
+	hash := hashPass(c.Password, member.Salt)
+	if subtle.ConstantTimeCompare(hash, member.PassHash) == 0 {
+		log.Printf("Invalid pass for %s", c.Name)
+		w.WriteHeader(http.StatusBadRequest)
+		return
+	}
+
+	log.Printf("Logged in as %s", c.Name)
+	w.Header().Set("Content-Type", "application/json")
+	w.WriteHeader(http.StatusOK)
+
+	token, err := a.newToken(member.Num, member.Role)
+	if err != nil {
+		log.Printf("Can't create a token: %v", err)
+		w.WriteHeader(http.StatusInternalServerError)
+		return
+	}
+	err = json.NewEncoder(w).Encode(map[string]interface{}{
+		"token":  token,
+		"member": member})
+	if err != nil {
+		log.Printf("Can't encode member: %v", err)
+		w.WriteHeader(http.StatusInternalServerError)
+	}
+}
+
+func (a *api) auth(fn func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
+	return func(w http.ResponseWriter, req *http.Request) {
+		token := req.Header.Get("x-authentication")
+		if !a.validToken(token) {
+			w.WriteHeader(http.StatusUnauthorized)
+			return
+		}
+		fn(w, req)
+	}
+}
+
+func (a *api) newToken(num int, role string) (string, error) {
+	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
+		"num":  num,
+		"role": role,
+		"exp":  time.Now().Add(time.Hour * 24).Unix(),
+	})
+	return token.SignedString(a.signKey)
+}
+
+func (a *api) validToken(token string) bool {
+	t, err := jwt.Parse(token, func(token *jwt.Token) (interface{}, error) {
+		return a.signKey, nil
+	})
+	if err != nil {
+		return false
+	}
+	if !t.Valid {
+		return false
+	}
+	claims, ok := t.Claims.(jwt.MapClaims)
+	if !ok {
+		return false
+	}
+	exp, ok := claims["exp"].(float64)
+	if !ok {
+		return false
+	}
+	// TODO: num, role
+	return time.Unix(int64(exp), 0).After(time.Now())
+}
+
+func newHashPass(password string) (hash []byte, salt []byte, err error) {
+	salt = make([]byte, 32)
+	_, err = rand.Read(salt)
+	if err != nil {
+		return
+	}
+
+	hash = hashPass(password, salt)
+	return
+}
+
+func hashPass(password string, salt []byte) []byte {
+	const (
+		time    = 1
+		memory  = 64 * 1024
+		threads = 2
+		keyLen  = 32
+	)
+
+	return argon2.IDKey([]byte(password), salt, time, memory, threads, keyLen)
+}
diff --git a/api/auth_test.go b/api/auth_test.go
new file mode 100644
index 0000000..534937f
--- /dev/null
+++ b/api/auth_test.go
@@ -0,0 +1,50 @@
+package api
+
+import (
+	"net/http"
+	"testing"
+)
+
+func TestSignIn(t *testing.T) {
+	tapi := newTestAPI(t)
+	defer tapi.close()
+
+	var member struct {
+		Member
+		Password string `json:"password"`
+	}
+	member.Num = 10
+	member.Name = "foo"
+	member.Password = "password"
+	resp := tapi.do("POST", "/member", member, nil)
+	if resp.StatusCode != http.StatusCreated {
+		t.Fatal("Can't create member:", resp.Status)
+	}
+
+	tapi.token = ""
+	resp = tapi.do("GET", "/member", nil, nil)
+	if resp.StatusCode != http.StatusUnauthorized {
+		t.Error("Got members without auth")
+	}
+
+	var respMember struct {
+		Token  string `json:"token"`
+		Member Member `json:"member"`
+	}
+	jsonAuth := creds{
+		Name:     member.Name,
+		Password: member.Password,
+	}
+	resp = tapi.do("POST", "/signin", jsonAuth, &respMember)
+	if resp.StatusCode != http.StatusOK {
+		t.Fatal("Can't sign in:", resp.Status)
+	}
+	if respMember.Member.Name != member.Name {
+		t.Fatal("Unexpected member:", respMember)
+	}
+	tapi.token = respMember.Token
+	resp = tapi.do("GET", "/member", nil, nil)
+	if resp.StatusCode != http.StatusOK {
+		t.Fatal("Can't get members:", resp.Status)
+	}
+}
diff --git a/api/member.go b/api/member.go
index 2f51a4b..2408b33 100644
--- a/api/member.go
+++ b/api/member.go
@@ -11,21 +11,39 @@ import (
 
 type Member struct {
 	gorm.Model `json:"-"`
-	Num        int    `json:"num"`
-	Name       string `json:"name"`
+	Num        int    `json:"num",gorm:"unique;index"`
+	Name       string `json:"name",gorm:"unique;index"`
 	Email      string `json:"email"`
 	Balance    int    `json:"balance"`
 	Role       string `json:"role"`
+	PassHash   []byte `json:"-"`
+	Salt       []byte `json:"-"`
 }
 
 func (a *api) AddMember(w http.ResponseWriter, req *http.Request) {
-	var member Member
-	err := json.NewDecoder(req.Body).Decode(&member)
+	var memberReq struct {
+		Member
+		Password string `json:"password"`
+	}
+	err := json.NewDecoder(req.Body).Decode(&memberReq)
 	if err != nil {
 		log.Printf("Can't create member: %v", err)
 		w.WriteHeader(http.StatusInternalServerError)
 		return
 	}
+	member := Member{
+		Num:     memberReq.Num,
+		Name:    memberReq.Name,
+		Email:   memberReq.Email,
+		Balance: memberReq.Balance,
+		Role:    memberReq.Role,
+	}
+	member.PassHash, member.Salt, err = newHashPass(memberReq.Password)
+	if err != nil {
+		log.Printf("Can't hash new member: %v\n%v", err, member)
+		w.WriteHeader(http.StatusInternalServerError)
+		return
+	}
 	err = a.db.Create(&member).Error
 	if err != nil {
 		log.Printf("Can't create member: %v\n%v", err, member)
diff --git a/go.mod b/go.mod
index 90aab06..a2d2f81 100644
--- a/go.mod
+++ b/go.mod
@@ -3,8 +3,10 @@ module 0xacab.org/meskio/cicer
 go 1.14
 
 require (
+	github.com/dgrijalva/jwt-go v3.2.0+incompatible
 	github.com/gorilla/mux v1.8.0
 	github.com/olivere/env v1.1.0
+	golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2
 	gorm.io/driver/sqlite v1.1.2
 	gorm.io/gorm v1.20.1
 )
diff --git a/go.sum b/go.sum
index ebc7181..2708a9b 100644
--- a/go.sum
+++ b/go.sum
@@ -1,5 +1,8 @@
 github.com/PuerkitoBio/goquery v1.5.1/go.mod h1:GsLWisAFVj4WgDibEWF4pvYnkVQBpKBKeU+7zCJoLcc=
 github.com/andybalholm/cascadia v1.1.0/go.mod h1:GsXiBklL0woXo1j/WYWtSYYC4ouU9PqHO0sqidkEA4Y=
+github.com/dgrijalva/jwt-go v1.0.2 h1:KPldsxuKGsS2FPWsNeg9ZO18aCrGKujPoWXn2yo+KQM=
+github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM=
+github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
 github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
 github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
 github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
@@ -10,11 +13,13 @@ github.com/mattn/go-sqlite3 v1.14.2 h1:A2EQLwjYf/hfYaM20FVjs1UewCTTFR7RmjEHkLjld
 github.com/mattn/go-sqlite3 v1.14.2/go.mod h1:JIl7NbARA7phWnGvh0LKTyg7S9BA+6gx71ShQilpsus=
 github.com/olivere/env v1.1.0 h1:owp/uwMwhru5668JjMDp8UTG3JGT27GTCk4ufYQfaTw=
 github.com/olivere/env v1.1.0/go.mod h1:zaoXy53SjZfxqZBGiGrZCkuVLYPdwrc+vArPuUVhJdQ=
+golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M=
 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
 golang.org/x/net v0.0.0-20180218175443-cbe0f9307d01/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
 golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884=
 golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 gorm.io/driver/sqlite v1.1.2 h1:6LsQVSO93WU4Xv2NTwIk2jE3bbKBLgMGmertBleuSTE=
diff --git a/main.go b/main.go
index 1028105..3959817 100644
--- a/main.go
+++ b/main.go
@@ -12,14 +12,15 @@ import (
 
 func main() {
 	var (
-		dbPath = flag.String("db-path", env.String("./test.db", "DB_PATH"), "Path where the sqlite will be located")
-		addr   = flag.String("addr", env.String(":8080", "HTTP_ADDR", "ADDR"), "Address where the http server will bind")
+		dbPath  = flag.String("db-path", env.String("./test.db", "DB_PATH"), "Path where the sqlite will be located")
+		addr    = flag.String("addr", env.String(":8080", "HTTP_ADDR", "ADDR"), "Address where the http server will bind")
+		signKey = flag.String("signkey", env.String("", "SIGNKEY"), "Sign key for authentication tokens. DO NOT LEAVE UNSET!!!")
 	)
 	flag.Parse()
 
 	r := mux.NewRouter()
 	apiRouter := r.PathPrefix("/api/").Subrouter()
-	err := api.Init(*dbPath, apiRouter)
+	err := api.Init(*dbPath, *signKey, apiRouter)
 	if err != nil {
 		log.Panicln("Can't open the database:", err)
 	}
-- 
GitLab