From 5c175bf8deb238c3a214ba13b3cdd8a74f0ad850 Mon Sep 17 00:00:00 2001
From: meskio <meskio@sindominio.net>
Date: Sun, 4 Oct 2020 17:09:15 +0200
Subject: [PATCH] Use DB transactions to update balance/create transaction

---
 api/purchase.go    | 39 +++++++++++++++++++----------------
 api/topup.go       | 12 +++--------
 api/transaction.go | 51 ++++++++++++++++++++++++++++------------------
 3 files changed, 55 insertions(+), 47 deletions(-)

diff --git a/api/purchase.go b/api/purchase.go
index 2d0e9c2..0eec3ed 100644
--- a/api/purchase.go
+++ b/api/purchase.go
@@ -2,6 +2,7 @@ package api
 
 import (
 	"encoding/json"
+	"fmt"
 	"log"
 	"net/http"
 	"time"
@@ -45,12 +46,6 @@ func (a *api) AddPurchase(num int, w http.ResponseWriter, req *http.Request) {
 		return
 	}
 
-	httpStatus := a.updateMemberBalance(num, -total)
-	if httpStatus != http.StatusOK {
-		w.WriteHeader(httpStatus)
-		return
-	}
-
 	transaction := Transaction{
 		MemberNum: num,
 		Date:      time.Now(),
@@ -58,20 +53,28 @@ func (a *api) AddPurchase(num int, w http.ResponseWriter, req *http.Request) {
 		Type:      "purchase",
 		Total:     -total,
 	}
-	err = a.db.Create(&transaction).Error
-	if err != nil {
-		log.Printf("Can't create purchase: %v\n%v", err, purchase)
-		w.WriteHeader(http.StatusInternalServerError)
-		return
-	}
-
-	for _, p := range purchase {
-		err := a.db.Model(&Product{}).
-			Where("code = ?", p.ProductCode).
-			Update("stock", gorm.Expr("stock - ?", p.Ammount)).Error
+	var httpStatus int
+	err = a.db.Transaction(func(tx *gorm.DB) error {
+		httpStatus, err = createTransaction(tx, &transaction)
 		if err != nil {
-			log.Printf("Can't update product stock %d-%d: %v", p.ProductCode, p.Ammount, err)
+			return err
+		}
+
+		for _, p := range purchase {
+			err := tx.Model(&Product{}).
+				Where("code = ?", p.ProductCode).
+				Update("stock", gorm.Expr("stock - ?", p.Ammount)).Error
+			if err != nil {
+				httpStatus = http.StatusInternalServerError
+				return fmt.Errorf("Can't update product stock %d-%d: %v", p.ProductCode, p.Ammount, err)
+			}
 		}
+		return nil
+	})
+	if err != nil {
+		log.Println(err)
+		w.WriteHeader(httpStatus)
+		return
 	}
 
 	w.Header().Set("Content-Type", "application/json")
diff --git a/api/topup.go b/api/topup.go
index 74ed855..4ee42a6 100644
--- a/api/topup.go
+++ b/api/topup.go
@@ -30,12 +30,6 @@ func (a *api) AddTopup(adminNum int, w http.ResponseWriter, req *http.Request) {
 		return
 	}
 
-	httpStatus := a.updateMemberBalance(topup.Member, topup.Ammount)
-	if httpStatus != http.StatusOK {
-		w.WriteHeader(httpStatus)
-		return
-	}
-
 	transaction := Transaction{
 		MemberNum: topup.Member,
 		Date:      time.Now(),
@@ -46,10 +40,10 @@ func (a *api) AddTopup(adminNum int, w http.ResponseWriter, req *http.Request) {
 		Type:  "topup",
 		Total: topup.Ammount,
 	}
-	err = a.db.Create(&transaction).Error
+	httpStatus, err := createTransaction(a.db, &transaction)
 	if err != nil {
-		log.Printf("Can't create topup: %v\n%v", err, transaction)
-		w.WriteHeader(http.StatusInternalServerError)
+		log.Println(err)
+		w.WriteHeader(httpStatus)
 		return
 	}
 
diff --git a/api/transaction.go b/api/transaction.go
index 0420e43..0ec5c63 100644
--- a/api/transaction.go
+++ b/api/transaction.go
@@ -2,6 +2,7 @@ package api
 
 import (
 	"encoding/json"
+	"fmt"
 	"log"
 	"net/http"
 	"strconv"
@@ -55,7 +56,7 @@ func (a *api) GetTransaction(num int, role string, w http.ResponseWriter, req *h
 			w.WriteHeader(http.StatusNotFound)
 			return
 		}
-		log.Printf("Can't get transaction %s: %v", vars["code"], err)
+		log.Printf("Can't get transaction %s: %v", vars["id"], err)
 		w.WriteHeader(http.StatusInternalServerError)
 		return
 	}
@@ -102,23 +103,33 @@ func (a *api) getTransactionsByMember(num int, w http.ResponseWriter, req *http.
 	}
 }
 
-func (a *api) updateMemberBalance(num int, ammount int) int {
-	var member Member
-	err := a.db.Where("num = ?", num).Find(&member).Error
-	if err != nil {
-		log.Printf("Can't find member %d: %v", num, err)
-		return http.StatusNotAcceptable
-	}
-	if member.Balance < -ammount {
-		log.Printf("Member %d don't have enough money (%d-%d)", num, member.Balance, ammount)
-		return http.StatusBadRequest
-	}
-	err = a.db.Model(&Member{}).
-		Where("num = ?", num).
-		Update("balance", gorm.Expr("balance + ?", ammount)).Error
-	if err != nil {
-		log.Printf("Can't update update member balance %d-%d: %v", num, ammount, err)
-		return http.StatusNotAcceptable
-	}
-	return http.StatusOK
+func createTransaction(db *gorm.DB, transaction *Transaction) (httpStatus int, err error) {
+	httpStatus = http.StatusOK
+	err = db.Transaction(func(tx *gorm.DB) error {
+		var member Member
+		err := tx.Where("num = ?", transaction.MemberNum).Find(&member).Error
+		if err != nil {
+			httpStatus = http.StatusNotAcceptable
+			return fmt.Errorf("Can't find member %d: %v", transaction.MemberNum, err)
+		}
+		if member.Balance < -transaction.Total {
+			httpStatus = http.StatusBadRequest
+			return fmt.Errorf("Member %d don't have enough money (%d-%d)", member.Num, member.Balance, transaction.Total)
+		}
+		err = tx.Model(&Member{}).
+			Where("num = ?", transaction.MemberNum).
+			Update("balance", gorm.Expr("balance + ?", transaction.Total)).Error
+		if err != nil {
+			httpStatus = http.StatusNotAcceptable
+			fmt.Errorf("Can't update update member balance %d-%d: %v", member.Num, transaction.Total, err)
+		}
+
+		err = tx.Create(&transaction).Error
+		if err != nil {
+			httpStatus = http.StatusInternalServerError
+			return fmt.Errorf("Can't create transaction: %v\n%v", err, transaction)
+		}
+		return nil
+	})
+	return
 }
-- 
GitLab