diff --git a/api/api.go b/api/api.go index 07679ce59ce3e31939ef8825714a448f34a8e616..57df5776f5a7138a86eaed19e70ff23caea66624 100644 --- a/api/api.go +++ b/api/api.go @@ -3,35 +3,23 @@ package api import ( "log" + "0xacab.org/meskio/cicer/api/db" "github.com/gorilla/mux" - "gorm.io/driver/sqlite" - "gorm.io/gorm" ) type api struct { - db *gorm.DB + db *db.DB signKey []byte mail *Mail } -func initDB(dbPath string) (*gorm.DB, error) { - db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{}) - if err != nil { - return nil, err - } - - db.AutoMigrate(&Member{}, &Product{}, &Purchase{}, &Topup{}, &Transaction{}, - &OrderPurchase{}, &Order{}, &PasswordReset{}) - return db, err -} - func Init(dbPath string, signKey string, mail *Mail, r *mux.Router) error { - db, err := initDB(dbPath) + database, err := db.Init(dbPath) if err != nil { return err } - a := api{db, []byte(signKey), mail} + a := api{database, []byte(signKey), mail} go a.refundOrders() go a.cleanPaswordResets() diff --git a/api/auth.go b/api/auth.go index 53de1d26f5c09980c3ddb2e25900a522c58306d2..c036561e70d45e39948adfc4bb6c9e7d9c349b2e 100644 --- a/api/auth.go +++ b/api/auth.go @@ -1,9 +1,6 @@ package api import ( - "crypto/rand" - "crypto/subtle" - "encoding/base64" "encoding/json" "errors" "log" @@ -11,23 +8,11 @@ import ( "net/url" "time" + "0xacab.org/meskio/cicer/api/db" "github.com/dgrijalva/jwt-go" "github.com/gorilla/mux" - "golang.org/x/crypto/argon2" - "gorm.io/gorm" ) -const ( - timeExpireResetToken = 2 * 24 * time.Hour -) - -type PasswordReset struct { - gorm.Model - Token string `gorm:"unique;index"` - MemberNum int `gorm:"column:member"` - Member *Member `gorm:"foreignKey:MemberNum;references:Num"` -} - type creds struct { Login string `json:"login"` Password string `json:"password"` @@ -50,16 +35,10 @@ func (a *api) SignIn(w http.ResponseWriter, req *http.Request) { w.WriteHeader(http.StatusInternalServerError) return } - var member Member - err = a.db.Where("login = ?", c.Login).First(&member).Error - if err != nil { - log.Printf("Can't locate user %s: %v", c.Login, err) - w.WriteHeader(http.StatusBadRequest) - return - } - if !passwordValid(c.Password, member) { - log.Printf("Invalid pass for %s", c.Login) + member, err := a.db.Login(c.Login, c.Password) + if err != nil { + log.Printf("Invalid pass for %s: %v", c.Login, err) w.WriteHeader(http.StatusBadRequest) return } @@ -127,38 +106,23 @@ func (a *api) SendPasswordReset(w http.ResponseWriter, req *http.Request) { w.WriteHeader(http.StatusInternalServerError) return } - var member Member - err = a.db.Where("email = ?", reset.Email).First(&member).Error - if err != nil { - log.Printf("Can't locate user %s: %v", reset.Email, err) - w.WriteHeader(http.StatusBadRequest) - return - } - - tokenBytes := make([]byte, 15) - _, err = rand.Read(tokenBytes) + member, token, err := a.db.NewPasswordReset(reset.Email) if err != nil { - log.Printf("Can't generate a random token for password reset: %v", err) + if errors.Is(err, db.ErrorNotFound) { + w.WriteHeader(http.StatusBadRequest) + return + } + log.Printf("Error creating password reset: %v", err) w.WriteHeader(http.StatusInternalServerError) return } - token := base64.URLEncoding.EncodeToString(tokenBytes) - passwordReset := PasswordReset{ - Token: token, - MemberNum: member.Num, - } + url := url.URL{ Scheme: req.URL.Scheme, Host: req.URL.Host, Path: "/api/reset/" + token, } - err = a.db.Transaction(func(tx *gorm.DB) error { - err = tx.Create(&passwordReset).Error - if err != nil { - return err - } - return a.mail.sendPasswordReset(member, url.String()) - }) + err = a.mail.sendPasswordReset(member, url.String()) if err != nil { log.Printf("Error sending password reset: %v", err) w.WriteHeader(http.StatusInternalServerError) @@ -169,28 +133,24 @@ func (a *api) SendPasswordReset(w http.ResponseWriter, req *http.Request) { } func (a *api) ValidatePasswordReset(w http.ResponseWriter, req *http.Request) { - passwordReset, status := a.getPasswordReset(req) - if status != http.StatusOK { - w.WriteHeader(status) + vars := mux.Vars(req) + token := vars["token"] + err := a.db.ValidPasswordReset(token) + if err != nil { + log.Printf("Can't get password reset %s: %v", token, err) + if errors.Is(err, db.ErrorNotFound) { + w.WriteHeader(http.StatusNotFound) + } else { + w.WriteHeader(http.StatusInternalServerError) + } return } - w.Header().Set("Content-Type", "application/json") + w.Write([]byte("Valid token")) w.WriteHeader(http.StatusOK) - err := json.NewEncoder(w).Encode(passwordReset.Member) - if err != nil { - log.Printf("Can't encode reset member: %v", err) - w.WriteHeader(http.StatusInternalServerError) - return - } } func (a *api) PasswordReset(w http.ResponseWriter, req *http.Request) { - passwordReset, status := a.getPasswordReset(req) - if status != http.StatusOK { - w.WriteHeader(status) - return - } var reset passwordResetPut err := json.NewDecoder(req.Body).Decode(&reset) if err != nil { @@ -198,51 +158,21 @@ func (a *api) PasswordReset(w http.ResponseWriter, req *http.Request) { w.WriteHeader(http.StatusInternalServerError) return } - hash, salt, err := newHashPass(reset.Password) - if err != nil { - log.Printf("Can't hash password: %v", err) - w.WriteHeader(http.StatusInternalServerError) - return - } - - err = a.db.Transaction(func(tx *gorm.DB) error { - err := a.db.Model(&Member{}). - Updates(Member{ - PassHash: hash, - Salt: salt, - }).Error - if err != nil { - return err - } - return a.db.Delete(passwordReset).Error - }) - if err != nil { - log.Printf("Error updating password: %v", err) - w.WriteHeader(http.StatusInternalServerError) - return - } - w.Write([]byte("Email sent")) - w.WriteHeader(http.StatusAccepted) -} -func (a *api) getPasswordReset(req *http.Request) (PasswordReset, int) { vars := mux.Vars(req) token := vars["token"] - - var passwordReset PasswordReset - err := a.db.Where("token = ?", token). - Preload("Member"). - First(&passwordReset).Error - status := http.StatusOK + err = a.db.ResetPassword(token, reset.Password) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - status = http.StatusNotFound + log.Printf("Can't reset password %s: %v", token, err) + if errors.Is(err, db.ErrorNotFound) { + w.WriteHeader(http.StatusNotFound) } else { - log.Printf("Can't get password reset %s: %v", token, err) - status = http.StatusInternalServerError + w.WriteHeader(http.StatusInternalServerError) } + return } - return passwordReset, status + w.Write([]byte("Email sent")) + w.WriteHeader(http.StatusAccepted) } func (a *api) cleanPaswordResets() { @@ -250,18 +180,7 @@ func (a *api) cleanPaswordResets() { const refundSleeptime = 10 * time.Minute for { time.Sleep(refundSleeptime) - a.cleanReset() - } -} - -func (a *api) cleanReset() { - t := time.Now().Add(timeExpireResetToken) - res := a.db.Where("created_at < ?", true, t). - Delete(&PasswordReset{}) - if res.Error != nil { - log.Println("Error deleting old reset tokens:", res.Error) - } else if res.RowsAffected != 0 { - log.Println("Deleted", res.RowsAffected, "password reset tokens") + a.db.CleanPasswordReset() } } @@ -391,30 +310,3 @@ func (a *api) validateToken(token string) (bool, jwt.MapClaims) { } return time.Unix(int64(exp), 0).After(time.Now()), claims } - -func newHashPass(password string) (hash []byte, salt []byte, err error) { - salt = make([]byte, 16) - _, err = rand.Read(salt) - if err != nil { - return - } - - hash = hashPass(password, salt) - return -} - -func passwordValid(password string, member Member) bool { - hash := hashPass(password, member.Salt) - return subtle.ConstantTimeCompare(hash, member.PassHash) == 1 -} - -func hashPass(password string, salt []byte) []byte { - const ( - time = 1 - memory = 64 * 1024 - threads = 2 - keyLen = 32 - ) - - return argon2.IDKey([]byte(password), salt, time, memory, threads, keyLen) -} diff --git a/api/auth_test.go b/api/auth_test.go index 5c49308fc84c15280bb5156f5a9351c75bb88349..4582b42bb542d04cfe4723e6c0db6befa687220d 100644 --- a/api/auth_test.go +++ b/api/auth_test.go @@ -3,6 +3,8 @@ package api import ( "net/http" "testing" + + "0xacab.org/meskio/cicer/api/db" ) func TestSignIn(t *testing.T) { @@ -17,8 +19,8 @@ func TestSignIn(t *testing.T) { } var respMember struct { - Token string `json:"token"` - Member Member `json:"member"` + Token string `json:"token"` + Member db.Member `json:"member"` } jsonAuth := creds{ Login: testMemberAdmin.Login, diff --git a/api/db/db.go b/api/db/db.go new file mode 100644 index 0000000000000000000000000000000000000000..6c0cd164f8602aca28b76e1ae14a0aad5cfb4742 --- /dev/null +++ b/api/db/db.go @@ -0,0 +1,21 @@ +package db + +import ( + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +type DB struct { + db *gorm.DB +} + +func Init(dbPath string) (*DB, error) { + db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{}) + if err != nil { + return nil, err + } + + db.AutoMigrate(&Member{}, &Product{}, &Purchase{}, &Topup{}, &Transaction{}, + &OrderPurchase{}, &Order{}, &PasswordReset{}) + return &DB{db}, err +} diff --git a/api/db/errors.go b/api/db/errors.go new file mode 100644 index 0000000000000000000000000000000000000000..e44f9c39d450edcdaee43cb75c0d8d12c7d417f6 --- /dev/null +++ b/api/db/errors.go @@ -0,0 +1,11 @@ +package db + +import ( + "errors" +) + +var ( + ErrorBadPassword = errors.New("Bad password") + ErrorInvalidRequest = errors.New("Invalid request") + ErrorNotFound = errors.New("Record not found") +) diff --git a/api/db/member.go b/api/db/member.go new file mode 100644 index 0000000000000000000000000000000000000000..62ea7e60e2005d04aacf4734edf8f2116314a63c --- /dev/null +++ b/api/db/member.go @@ -0,0 +1,228 @@ +package db + +import ( + "crypto/rand" + "crypto/subtle" + "encoding/base64" + "errors" + "log" + "time" + + "golang.org/x/crypto/argon2" + "gorm.io/gorm" +) + +const ( + timeExpireResetToken = 2 * 24 * time.Hour +) + +type Member struct { + gorm.Model `json:"-"` + Num int `json:"num" gorm:"unique;index"` + Login string `json:"login" gorm:"unique;index"` + Name string `json:"name"` + Email string `json:"email"` + Phone string `json:"phone"` + Balance int `json:"balance"` + Role string `json:"role"` + PassHash []byte `json:"-"` + Salt []byte `json:"-"` +} + +type PasswordReset struct { + gorm.Model + Token string `gorm:"unique;index"` + MemberNum int `gorm:"column:member"` + Member *Member `gorm:"foreignKey:MemberNum;references:Num"` +} + +type MemberReq struct { + Member + OldPassword string `json:"old_password"` + Password string `json:"password"` +} + +func (d DB) AddMember(memberReq *MemberReq) (member Member, err error) { + member.Num = memberReq.Num + member.Login = memberReq.Login + member.Name = memberReq.Name + member.Email = memberReq.Email + member.Phone = memberReq.Phone + member.Balance = memberReq.Balance + member.Role = memberReq.Role + + member.PassHash, member.Salt, err = newHashPass(memberReq.Password) + if err != nil { + return + } + + err = d.db.Create(&member).Error + return +} + +func (d DB) ListMembers() (members []Member, err error) { + err = d.db.Find(&members).Error + return +} + +func (d DB) GetMember(num int) (member Member, err error) { + err = d.db.Where("num = ?", num).First(&member).Error + if errors.Is(err, gorm.ErrRecordNotFound) { + err = ErrorNotFound + } + return +} + +func (d DB) DeleteMember(num int) error { + return d.db.Where("num = ?", num).Delete(&Member{}).Error +} + +func (d DB) UpdateMember(num int, member MemberReq, checkPass bool) (Member, error) { + var dbMember Member + err := d.db.Where("num = ?", num).First(&dbMember).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + err = ErrorNotFound + } + return dbMember, err + } + if checkPass && !passwordValid(member.OldPassword, dbMember) { + return dbMember, ErrorBadPassword + } + + if member.Num != 0 { + dbMember.Num = member.Num + } + if member.Login != "" { + dbMember.Login = member.Login + } + if member.Name != "" { + dbMember.Name = member.Name + } + if member.Email != "" { + dbMember.Email = member.Email + } + if member.Role != "" { + dbMember.Role = member.Role + } + if member.Password != "" { + dbMember.PassHash, dbMember.Salt, err = newHashPass(member.Password) + if err != nil { + return dbMember, err + } + } + err = d.db.Save(&dbMember).Error + return dbMember, err +} + +func (d DB) Login(login, password string) (member Member, err error) { + err = d.db.Where("login = ?", login).First(&member).Error + if err != nil { + return + } + + if !passwordValid(password, member) { + err = ErrorBadPassword + } + return +} + +func (d DB) NewPasswordReset(email string) (member Member, token string, err error) { + err = d.db.Where("email = ?", email).First(&member).Error + if err != nil { + log.Printf("Can't locate user %s: %v", email, err) + err = ErrorNotFound + return + } + + tokenBytes := make([]byte, 15) + _, err = rand.Read(tokenBytes) + if err != nil { + log.Printf("Can't generate a random token for password reset: %v", err) + return + } + token = base64.URLEncoding.EncodeToString(tokenBytes) + passwordReset := PasswordReset{ + Token: token, + MemberNum: member.Num, + } + err = d.db.Create(&passwordReset).Error + return +} + +func (d *DB) ValidPasswordReset(token string) error { + _, err := d.getPasswordReset(token) + return err +} + +func (d *DB) ResetPassword(token, password string) error { + passwordReset, err := d.getPasswordReset(token) + if err != nil { + return err + } + + hash, salt, err := newHashPass(password) + if err != nil { + return err + } + + return d.db.Transaction(func(tx *gorm.DB) error { + err := tx.Model(&Member{}). + Updates(Member{ + PassHash: hash, + Salt: salt, + }).Error + if err != nil { + return err + } + return tx.Delete(passwordReset).Error + }) +} + +func (d *DB) getPasswordReset(token string) (passwordReset PasswordReset, err error) { + err = d.db.Where("token = ?", token). + Preload("Member"). + First(&passwordReset).Error + if errors.Is(err, gorm.ErrRecordNotFound) { + err = ErrorNotFound + } + return +} + +func (d *DB) CleanPasswordReset() { + t := time.Now().Add(timeExpireResetToken) + res := d.db.Where("created_at < ?", true, t). + Delete(&PasswordReset{}) + if res.Error != nil { + log.Println("Error deleting old reset tokens:", res.Error) + } else if res.RowsAffected != 0 { + log.Println("Deleted", res.RowsAffected, "password reset tokens") + } +} + +func newHashPass(password string) (hash []byte, salt []byte, err error) { + salt = make([]byte, 16) + _, err = rand.Read(salt) + if err != nil { + return + } + + hash = hashPass(password, salt) + return +} + +func passwordValid(password string, member Member) bool { + hash := hashPass(password, member.Salt) + return subtle.ConstantTimeCompare(hash, member.PassHash) == 1 +} + +func hashPass(password string, salt []byte) []byte { + const ( + time = 1 + memory = 64 * 1024 + threads = 2 + keyLen = 32 + ) + + return argon2.IDKey([]byte(password), salt, time, memory, threads, keyLen) +} diff --git a/api/db/order.go b/api/db/order.go new file mode 100644 index 0000000000000000000000000000000000000000..d1e3335a734de74ab6d3a4a7fe672fab7402a887 --- /dev/null +++ b/api/db/order.go @@ -0,0 +1,167 @@ +package db + +import ( + "errors" + "log" + "time" + + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +type Order struct { + gorm.Model + Name string `json:"name"` + Description string `json:"description"` + MemberNum int `json:"member_num" gorm:"column:member"` + Member *Member `json:"member,omitempty" gorm:"foreignKey:MemberNum;references:Num"` + Deadline time.Time `json:"deadline"` + Active bool `json:"active" gorm:"index"` + + Products []Product `json:"products" gorm:"many2many:order_products;References:Code;JoinReferences:ProductCode"` + Transactions []Transaction `json:"transactions" gorm:"foreignKey:OrderID"` + TransactionID *uint `json:"-" gorm:"column:transaction"` +} + +type OrderPurchase struct { + gorm.Model `json:"-"` + TransactionID uint `json:"-"` + ProductCode int `json:"product_code"` + Product *Product `json:"product" gorm:"foreignKey:ProductCode;references:Code"` + Price int `json:"price"` + Amount int `json:"amount"` +} + +func (d *DB) ListOrders(active bool) (orders []Order, err error) { + query := d.db.Preload(clause.Associations). + Preload("Transactions.OrderPurchase") + if active { + query = query.Where("active = ?", true) + } + err = query.Order("deadline desc"). + Find(&orders).Error + return +} + +func (d *DB) GetOrder(memberNum int, id int) (order Order, transaction Transaction, err error) { + err = d.db.Preload(clause.Associations). + Preload("Transactions.OrderPurchase"). + Preload("Transactions.Member"). + First(&order, id).Error + if errors.Is(err, gorm.ErrRecordNotFound) { + err = ErrorNotFound + return + } + err = d.db.Where("member = ? AND type = 'order' AND order_id = ?", memberNum, id). + Preload("OrderPurchase.Product"). + Find(&transaction).Error + return +} + +func (d *DB) AddOrder(order *Order) error { + return d.db.Create(&order).Error +} + +func (d *DB) AddOrderPurchase(memberNum int, orderID uint, purchase []OrderPurchase) (transaction Transaction, err error) { + var order Order + err = d.db.Preload("Products"). + Preload("Transactions"). + First(&order, orderID).Error + if err != nil { + return + } + if !order.Active { + err = ErrorInvalidRequest + log.Printf("Order is not active %d: %v", order.ID, purchase) + return + } + for _, t := range order.Transactions { + if t.MemberNum == memberNum { + log.Printf("Purchase by %d for %d when there is already one by this member: %v", memberNum, order.ID, purchase) + err = ErrorInvalidRequest + return + } + } + + total := 0 + for i, p := range purchase { + found := false + for _, product := range order.Products { + if product.Code == p.ProductCode { + total += product.Price * p.Amount + purchase[i].Price = product.Price + found = true + break + } + } + + if !found { + log.Printf("Order purchase product %d not in order: %v", p.ProductCode, purchase) + err = ErrorInvalidRequest + return + } + } + + transaction = Transaction{ + MemberNum: memberNum, + Total: -total, + Type: "order", + Date: time.Now(), + OrderPurchase: purchase, + OrderID: &order.ID, + } + err = createTransaction(d.db, &transaction) + return +} + +func (d *DB) DeactivateOrders() []Order { + var orders []Order + now := time.Now() + t := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.Local) + err := d.db.Where("active = ? AND deadline < ?", true, t). + Preload("Member"). + Preload("Transactions.OrderPurchase.Product"). + Preload("Transactions.Member"). + Find(&orders).Error + if err != nil { + log.Println("Error refunding orders:", err) + return []Order{} + } + + var deactivatedOrders []Order + for _, order := range orders { + total := 0 + for _, transaction := range order.Transactions { + for _, purchase := range transaction.OrderPurchase { + total += purchase.Price * purchase.Amount + } + } + + transaction := Transaction{ + MemberNum: order.MemberNum, + Date: time.Now(), + Type: "refund", + Total: total, + } + err = d.db.Transaction(func(tx *gorm.DB) error { + err := createTransaction(tx, &transaction) + if err != nil { + return err + } + return tx.Model(&Order{}). + Where("id = ?", order.ID). + Updates(map[string]interface{}{ + "active": false, + "transaction": transaction.ID}). + Error + }) + if err != nil { + log.Printf("Can't create refund: %v\n%v", err, order) + continue + } + + deactivatedOrders = append(deactivatedOrders, order) + log.Println("Refund order", order.Name, total) + } + return deactivatedOrders +} diff --git a/api/db/product.go b/api/db/product.go new file mode 100644 index 0000000000000000000000000000000000000000..238af810283f173c4374d126ad967b952d11bd23 --- /dev/null +++ b/api/db/product.go @@ -0,0 +1,62 @@ +package db + +import ( + "errors" + + "gorm.io/gorm" +) + +type Product struct { + gorm.Model `json:"-"` + Code int `json:"code" gorm:"unique;index"` + Name string `json:"name" gorm:"unique;index"` + Price int `json:"price"` + Stock int `json:"stock"` +} + +func (d *DB) AddProduct(product *Product) error { + return d.db.Create(&product).Error +} + +func (d *DB) ListProducts() (products []Product, err error) { + err = d.db.Find(&products).Error + return +} + +func (d *DB) GetProduct(code int) (product Product, err error) { + err = d.db.Where("code = ?", code).First(&product).Error + if errors.Is(err, gorm.ErrRecordNotFound) { + err = ErrorNotFound + } + return +} + +func (d *DB) DeleteProduct(code int) error { + return d.db.Where("code = ?", code). + Delete(&Product{}).Error +} + +func (d *DB) UpdateProduct(code int, product *Product) (dbProduct Product, err error) { + err = d.db.Where("code = ?", code).First(&dbProduct).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + err = ErrorNotFound + } + return + } + + if product.Code != 0 { + dbProduct.Code = product.Code + } + if product.Name != "" { + dbProduct.Name = product.Name + } + if product.Price >= 0 { + dbProduct.Price = product.Price + } + if product.Stock >= 0 { + dbProduct.Stock = product.Stock + } + err = d.db.Save(&dbProduct).Error + return +} diff --git a/api/db/transaction.go b/api/db/transaction.go new file mode 100644 index 0000000000000000000000000000000000000000..408956e154f726ef1def0e74264a4d7170a16535 --- /dev/null +++ b/api/db/transaction.go @@ -0,0 +1,160 @@ +package db + +import ( + "errors" + "fmt" + "log" + "time" + + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +type Transaction struct { + gorm.Model + MemberNum int `json:"-" gorm:"column:member"` + Member *Member `json:"member,omitempty" gorm:"foreignKey:MemberNum;references:Num"` + Date time.Time `json:"date"` + Total int `json:"total"` + Type string `json:"type"` + + Purchase []Purchase `json:"purchase,omitempty"` + Topup *Topup `json:"topup,omitempty"` + OrderPurchase []OrderPurchase `json:"order_purchase,omitempty" gorm:"foreignKey:TransactionID"` + Order *Order `json:"order,omitempty"` + OrderID *uint `json:"-"` + Refund *Order `json:"refund,omitempty" gorm:"foreignKey:TransactionID"` +} + +type Topup struct { + gorm.Model `json:"-"` + TransactionID uint `json:"-" gorm:"column:transaction"` + MemberNum int `json:"member" gorm:"column:member"` + Member Member `json:"-" gorm:"foreignKey:MemberNum;references:Num"` + Comment string `json:"comment"` +} + +type Purchase struct { + gorm.Model `json:"-"` + TransactionID uint `json:"-" gorm:"column:transaction"` + ProductCode int `json:"code" gorm:"column:product"` + Product Product `json:"product" gorm:"foreignKey:ProductCode;references:Code"` + Price int `json:"price"` + Amount int `json:"amount"` +} + +func (d *DB) ListTransactions() (transactions []Transaction, err error) { + err = d.transactionQuery(). + Order("date desc"). + Find(&transactions).Error + return +} + +func (d *DB) TransactionByMember(num int) (transactions []Transaction, err error) { + err = d.transactionQuery(). + Where("member = ?", num). + Order("date desc"). + Find(&transactions).Error + return +} + +func (d *DB) GetTransaction(id int) (transaction Transaction, err error) { + err = d.transactionQuery(). + First(&transaction, id).Error + if errors.Is(err, gorm.ErrRecordNotFound) { + err = ErrorNotFound + } + return +} + +func (d *DB) AddTopup(adminNum int, memberNum int, amount int, comment string) (transaction Transaction, err error) { + transaction = Transaction{ + MemberNum: memberNum, + Date: time.Now(), + Topup: &Topup{ + MemberNum: adminNum, + Comment: comment, + }, + Type: "topup", + Total: amount, + } + err = createTransaction(d.db, &transaction) + return +} + +func (d *DB) AddPurchase(memberNum int, purchase []Purchase) (transaction Transaction, err error) { + total := 0 + for i, p := range purchase { + var product Product + err = d.db.Where("code = ?", p.ProductCode).First(&product).Error + if err != nil { + log.Printf("Can't get product %d: %v", p.ProductCode, err) + err = ErrorNotFound + return + } + + total += product.Price * p.Amount + purchase[i].Price = product.Price + } + if total == 0 { + log.Printf("Empty purchase (%d)", memberNum) + err = ErrorInvalidRequest + return + } + + transaction = Transaction{ + MemberNum: memberNum, + Date: time.Now(), + Purchase: purchase, + Type: "purchase", + Total: -total, + } + err = d.db.Transaction(func(tx *gorm.DB) error { + err := createTransaction(tx, &transaction) + if err != nil { + return err + } + + for _, p := range purchase { + err = tx.Model(&Product{}). + Where("code = ?", p.ProductCode). + Update("stock", gorm.Expr("stock - ?", p.Amount)).Error + if err != nil { + return fmt.Errorf("Can't update product stock %d-%d: %v", p.ProductCode, p.Amount, err) + } + } + return nil + }) + return +} + +func (d *DB) transactionQuery() *gorm.DB { + return d.db.Preload("Purchase.Product"). + Preload("Order.Products"). + Preload("OrderPurchase.Product"). + Preload(clause.Associations) +} + +func createTransaction(db *gorm.DB, transaction *Transaction) error { + return db.Transaction(func(tx *gorm.DB) error { + var member Member + err := tx.Where("num = ?", transaction.MemberNum).Find(&member).Error + if err != nil { + log.Printf("Can't find member for transaction %d: %v", transaction.MemberNum, err) + return ErrorNotFound + } + if member.Balance < -transaction.Total { + log.Printf("Member %d don't have enough money (%d-%d)", member.Num, member.Balance, transaction.Total) + return ErrorInvalidRequest + } + err = tx.Model(&Member{}). + Where("num = ?", transaction.MemberNum). + Update("balance", gorm.Expr("balance + ?", transaction.Total)).Error + if err != nil { + log.Printf("Can't update update member balance %d-%d: %v", member.Num, transaction.Total, err) + return err + } + + return tx.Create(&transaction).Error + }) +} diff --git a/api/mail.go b/api/mail.go index 7d904a35205718ebd4dcd26748d2ef043ba44075..8d425af87dbd8ddcb0715128dd7656214a657bbe 100644 --- a/api/mail.go +++ b/api/mail.go @@ -5,6 +5,8 @@ import ( "net/smtp" "strings" "text/template" + + "0xacab.org/meskio/cicer/api/db" ) const ( @@ -71,18 +73,18 @@ type orderData struct { MemberName string OrderName string Products map[string]int - Purchases map[string][]OrderPurchase + Purchases map[string][]db.OrderPurchase } -func (m Mail) sendOrder(to string, order *Order) error { +func (m Mail) sendOrder(to string, order *db.Order) error { if m.server == "" { return nil } products := make(map[string]int) - purchases := make(map[string][]OrderPurchase) + purchases := make(map[string][]db.OrderPurchase) for _, t := range order.Transactions { - var purchase []OrderPurchase + var purchase []db.OrderPurchase for _, p := range t.OrderPurchase { if p.Amount == 0 { continue @@ -116,7 +118,7 @@ type passwordResetData struct { Link string } -func (m Mail) sendPasswordReset(member Member, link string) error { +func (m Mail) sendPasswordReset(member db.Member, link string) error { if m.server == "" { return nil } diff --git a/api/member.go b/api/member.go index 1647c5339757dafb3c4d7644e7baf3d0a83a9a8b..e58af170aeb2831e81fd154f8239ca85f691b2f7 100644 --- a/api/member.go +++ b/api/member.go @@ -7,57 +7,19 @@ import ( "net/http" "strconv" + "0xacab.org/meskio/cicer/api/db" "github.com/gorilla/mux" - "gorm.io/gorm" ) -var ( - ErroBadPassword = errors.New("Bad password") -) - -type Member struct { - gorm.Model `json:"-"` - Num int `json:"num" gorm:"unique;index"` - Login string `json:"login" gorm:"unique;index"` - Name string `json:"name"` - Email string `json:"email"` - Phone string `json:"phone"` - Balance int `json:"balance"` - Role string `json:"role"` - PassHash []byte `json:"-"` - Salt []byte `json:"-"` -} - -type MemberReq struct { - Member - OldPassword string `json:"old_password"` - Password string `json:"password"` -} - func (a *api) AddMember(w http.ResponseWriter, req *http.Request) { - var memberReq MemberReq + var memberReq db.MemberReq err := json.NewDecoder(req.Body).Decode(&memberReq) if err != nil { log.Printf("Can't create member: %v", err) w.WriteHeader(http.StatusInternalServerError) return } - member := Member{ - Num: memberReq.Num, - Login: memberReq.Login, - Name: memberReq.Name, - Email: memberReq.Email, - Phone: memberReq.Phone, - Balance: memberReq.Balance, - Role: memberReq.Role, - } - member.PassHash, member.Salt, err = newHashPass(memberReq.Password) - if err != nil { - log.Printf("Can't hash new member: %v\n%v", err, member) - w.WriteHeader(http.StatusInternalServerError) - return - } - err = a.db.Create(&member).Error + member, err := a.db.AddMember(&memberReq) if err != nil { log.Printf("Can't create member: %v\n%v", err, member) w.WriteHeader(http.StatusInternalServerError) @@ -75,8 +37,7 @@ func (a *api) AddMember(w http.ResponseWriter, req *http.Request) { } func (a *api) ListMembers(w http.ResponseWriter, req *http.Request) { - var members []Member - err := a.db.Find(&members).Error + members, err := a.db.ListMembers() if err != nil { log.Printf("Can't list members: %v", err) w.WriteHeader(http.StatusInternalServerError) @@ -99,10 +60,9 @@ func (a *api) GetMember(w http.ResponseWriter, req *http.Request) { } func (a *api) getMemberNum(num int, w http.ResponseWriter, req *http.Request) { - var member Member - err := a.db.Where("num = ?", num).First(&member).Error + member, err := a.db.GetMember(num) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { + if errors.Is(err, db.ErrorNotFound) { w.WriteHeader(http.StatusNotFound) return } @@ -122,7 +82,8 @@ func (a *api) getMemberNum(num int, w http.ResponseWriter, req *http.Request) { func (a *api) DeleteMember(w http.ResponseWriter, req *http.Request) { vars := mux.Vars(req) - err := a.db.Where("num = ?", vars["num"]).Delete(&Member{}).Error + num, _ := strconv.Atoi(vars["num"]) + err := a.db.DeleteMember(num) if err != nil { log.Printf("Can't delete member %s: %v", vars["num"], err) w.WriteHeader(http.StatusInternalServerError) @@ -132,7 +93,7 @@ func (a *api) DeleteMember(w http.ResponseWriter, req *http.Request) { } func (a *api) UpdateMember(w http.ResponseWriter, req *http.Request) { - var member MemberReq + var member db.MemberReq err := json.NewDecoder(req.Body).Decode(&member) if err != nil { log.Printf("Can't decode member: %v", err) @@ -148,9 +109,9 @@ func (a *api) UpdateMember(w http.ResponseWriter, req *http.Request) { return } - m, err := a.updateMember(num, member, false) + m, err := a.db.UpdateMember(num, member, false) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { + if errors.Is(err, db.ErrorNotFound) { w.WriteHeader(http.StatusNotFound) return } @@ -170,7 +131,7 @@ func (a *api) UpdateMember(w http.ResponseWriter, req *http.Request) { } func (a *api) UpdateMemberMe(num int, w http.ResponseWriter, req *http.Request) { - var member MemberReq + var member db.MemberReq err := json.NewDecoder(req.Body).Decode(&member) if err != nil { log.Printf("Can't decode member: %v", err) @@ -180,11 +141,11 @@ func (a *api) UpdateMemberMe(num int, w http.ResponseWriter, req *http.Request) member.Num = 0 member.Role = "" - m, err := a.updateMember(num, member, true) + m, err := a.db.UpdateMember(num, member, true) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { + if errors.Is(err, db.ErrorNotFound) { w.WriteHeader(http.StatusNotFound) - } else if errors.Is(err, ErroBadPassword) { + } else if errors.Is(err, db.ErrorBadPassword) { w.WriteHeader(http.StatusBadRequest) } else { log.Printf("Can't update member %d: %v", num, err) @@ -202,38 +163,3 @@ func (a *api) UpdateMemberMe(num int, w http.ResponseWriter, req *http.Request) return } } - -func (a *api) updateMember(num int, member MemberReq, checkPass bool) (Member, error) { - var dbMember Member - err := a.db.Where("num = ?", num).First(&dbMember).Error - if err != nil { - return dbMember, err - } - if checkPass && !passwordValid(member.OldPassword, dbMember) { - return dbMember, ErroBadPassword - } - - if member.Num != 0 { - dbMember.Num = member.Num - } - if member.Login != "" { - dbMember.Login = member.Login - } - if member.Name != "" { - dbMember.Name = member.Name - } - if member.Email != "" { - dbMember.Email = member.Email - } - if member.Role != "" { - dbMember.Role = member.Role - } - if member.Password != "" { - dbMember.PassHash, dbMember.Salt, err = newHashPass(member.Password) - if err != nil { - return dbMember, err - } - } - err = a.db.Save(&dbMember).Error - return dbMember, err -} diff --git a/api/member_test.go b/api/member_test.go index b9c5b4753fe853e1dd8ef82cd7f73f0885e099eb..39e20431ae843f9b4f30d2a119c2a0da372bb34d 100644 --- a/api/member_test.go +++ b/api/member_test.go @@ -3,10 +3,12 @@ package api import ( "net/http" "testing" + + "0xacab.org/meskio/cicer/api/db" ) -var testMember = MemberReq{ - Member: Member{ +var testMember = db.MemberReq{ + Member: db.Member{ Num: 10, Login: "foo", Name: "Foo Baz", @@ -17,8 +19,8 @@ var testMember = MemberReq{ Password: "password", } -var testMemberAdmin = MemberReq{ - Member: Member{ +var testMemberAdmin = db.MemberReq{ + Member: db.Member{ Num: 20, Login: "bar", Name: "Bar Baz", @@ -34,7 +36,7 @@ func TestMemberAddList(t *testing.T) { defer tapi.close() tapi.addTestMember() - var members []Member + var members []db.Member resp := tapi.doAdmin("GET", "/member", nil, &members) if resp.StatusCode != http.StatusOK { t.Fatal("Can't get members:", resp.Status) @@ -61,7 +63,7 @@ func TestMemberGetDelete(t *testing.T) { } tapi.addTestMember() - var gotMember Member + var gotMember db.Member resp = tapi.doAdmin("GET", "/member/10", nil, &gotMember) if resp.StatusCode != http.StatusOK { t.Error("Can't find the member:", resp.Status) @@ -92,7 +94,7 @@ func TestMemberUpdate(t *testing.T) { t.Fatal("Can't update member:", resp.Status) } - var gotMember Member + var gotMember db.Member resp = tapi.doAdmin("GET", "/member/10", nil, &gotMember) if resp.StatusCode != http.StatusOK { t.Error("Can't find the member:", resp.Status) @@ -122,7 +124,7 @@ func TestMemberUpdateMe(t *testing.T) { t.Fatal("Can't update member:", resp.Status) } - var gotMember Member + var gotMember db.Member resp = tapi.doAdmin("GET", "/member/10", nil, &gotMember) if resp.StatusCode != http.StatusOK { t.Error("Can't find the member:", resp.Status) diff --git a/api/order.go b/api/order.go index 3ab98dea4bf4bd08bda77e172b99f5e9cb16401b..a296f50de53de3c1bcd17fc39e567f7d719be32d 100644 --- a/api/order.go +++ b/api/order.go @@ -5,104 +5,34 @@ import ( "errors" "log" "net/http" + "strconv" "time" + "0xacab.org/meskio/cicer/api/db" "github.com/gorilla/mux" - "gorm.io/gorm" - "gorm.io/gorm/clause" ) -type Order struct { - gorm.Model - Name string `json:"name"` - Description string `json:"description"` - MemberNum int `json:"member_num" gorm:"column:member"` - Member *Member `json:"member,omitempty" gorm:"foreignKey:MemberNum;references:Num"` - Deadline time.Time `json:"deadline"` - Active bool `json:"active" gorm:"index"` - - Products []Product `json:"products" gorm:"many2many:order_products;References:Code;JoinReferences:ProductCode"` - Transactions []Transaction `json:"transactions" gorm:"foreignKey:OrderID"` - TransactionID *uint `json:"-" gorm:"column:transaction"` -} - -type OrderPurchase struct { - gorm.Model `json:"-"` - TransactionID uint `json:"-"` - ProductCode int `json:"product_code"` - Product *Product `json:"product" gorm:"foreignKey:ProductCode;references:Code"` - Price int `json:"price"` - Amount int `json:"amount"` -} - type OrderGetResponse struct { - Order Order `json:"order"` - Transaction *Transaction `json:"transaction"` + Order db.Order `json:"order"` + Transaction *db.Transaction `json:"transaction"` } type OrderPurchaseRequest struct { - Purchase []OrderPurchase `json:"purchase"` - OrderID uint `json:"order"` + Purchase []db.OrderPurchase `json:"purchase"` + OrderID uint `json:"order"` } func (a *api) refundOrders() { const refundSleeptime = 10 * time.Minute for { time.Sleep(refundSleeptime) - a.deactivateOrders() - } -} - -func (a *api) deactivateOrders() { - var orders []Order - now := time.Now() - t := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.Local) - err := a.db.Where("active = ? AND deadline < ?", true, t). - Preload("Member"). - Preload("Transactions.OrderPurchase.Product"). - Preload("Transactions.Member"). - Find(&orders).Error - if err != nil { - log.Println("Error refunding orders:", err) - return - } - - for _, order := range orders { - total := 0 - for _, transaction := range order.Transactions { - for _, purchase := range transaction.OrderPurchase { - total += purchase.Price * purchase.Amount - } - } - - transaction := Transaction{ - MemberNum: order.MemberNum, - Date: time.Now(), - Type: "refund", - Total: total, - } - err = a.db.Transaction(func(tx *gorm.DB) error { - _, err := createTransaction(tx, &transaction) + orders := a.db.DeactivateOrders() + for _, order := range orders { + err := a.mail.sendOrder(order.Member.Email, &order) if err != nil { - return err + log.Println("Error sending order email:", err) } - return tx.Model(&Order{}). - Where("id = ?", order.ID). - Updates(map[string]interface{}{ - "active": false, - "transaction": transaction.ID}). - Error - }) - if err != nil { - log.Printf("Can't create refund: %v\n%v", err, order) - continue - } - - err := a.mail.sendOrder(order.Member.Email, &order) - if err != nil { - log.Println("Error sending order email:", err) } - log.Println("Refund order", order.Name, total) } } @@ -115,19 +45,13 @@ func (a *api) ListActiveOrders(w http.ResponseWriter, req *http.Request) { } func (a *api) listOrders(active bool, w http.ResponseWriter, req *http.Request) { - var orders []Order - query := a.db.Preload(clause.Associations). - Preload("Transactions.OrderPurchase") - if active { - query = query.Where("active = ?", true) - } - err := query.Order("deadline desc"). - Find(&orders).Error + orders, err := a.db.ListOrders(active) if err != nil { log.Printf("Can't list orders: %v", err) w.WriteHeader(http.StatusInternalServerError) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) err = json.NewEncoder(w).Encode(orders) @@ -139,13 +63,10 @@ func (a *api) listOrders(active bool, w http.ResponseWriter, req *http.Request) func (a *api) GetOrder(num int, w http.ResponseWriter, req *http.Request) { vars := mux.Vars(req) - var order Order - err := a.db.Preload(clause.Associations). - Preload("Transactions.OrderPurchase"). - Preload("Transactions.Member"). - First(&order, vars["id"]).Error + id, _ := strconv.Atoi(vars["id"]) + order, transaction, err := a.db.GetOrder(num, id) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { + if errors.Is(err, db.ErrorNotFound) { w.WriteHeader(http.StatusNotFound) return } @@ -153,18 +74,9 @@ func (a *api) GetOrder(num int, w http.ResponseWriter, req *http.Request) { w.WriteHeader(http.StatusInternalServerError) return } + var body OrderGetResponse body.Order = order - - var transaction Transaction - err = a.db.Where("member = ? AND type = 'order' AND order_id = ?", num, vars["id"]). - Preload("OrderPurchase.Product"). - Find(&transaction).Error - if err != nil { - log.Printf("Can't get order transaction %s: %v", vars["id"], err) - w.WriteHeader(http.StatusInternalServerError) - return - } if transaction.ID != 0 { body.Transaction = &transaction } @@ -180,7 +92,7 @@ func (a *api) GetOrder(num int, w http.ResponseWriter, req *http.Request) { } func (a *api) AddOrder(num int, w http.ResponseWriter, req *http.Request) { - var order Order + var order db.Order err := json.NewDecoder(req.Body).Decode(&order) if err != nil { log.Printf("Can't parse order: %v", err) @@ -190,7 +102,7 @@ func (a *api) AddOrder(num int, w http.ResponseWriter, req *http.Request) { order.MemberNum = num order.Active = true - err = a.db.Create(&order).Error + err = a.db.AddOrder(&order) if err != nil { log.Printf("Can't create order: %v", err) w.WriteHeader(http.StatusInternalServerError) @@ -221,59 +133,16 @@ func (a *api) AddOrderPurchase(num int, w http.ResponseWriter, req *http.Request return } - var order Order - err = a.db.Preload("Products"). - Preload("Transactions"). - First(&order, request.OrderID).Error + transaction, err := a.db.AddOrderPurchase(num, request.OrderID, request.Purchase) if err != nil { - log.Printf("Can't get order %d: %v", request.OrderID, err) - w.WriteHeader(http.StatusInternalServerError) - return - } - if !order.Active { - log.Printf("Order is not active %d: %v", order.ID, request) - w.WriteHeader(http.StatusBadRequest) - return - } - for _, t := range order.Transactions { - if t.MemberNum == num { - log.Printf("Purchase by %d for %d when there is already one by this member: %v", num, order.ID, request) - w.WriteHeader(http.StatusBadRequest) - return - } - } - - total := 0 - for i, p := range request.Purchase { - found := false - for _, product := range order.Products { - if product.Code == p.ProductCode { - total += product.Price * p.Amount - request.Purchase[i].Price = product.Price - found = true - break - } - } - - if !found { - log.Printf("Order purchase product %d not in order: %v", p.ProductCode, request) + if errors.Is(err, db.ErrorNotFound) { + w.WriteHeader(http.StatusNotAcceptable) + } else if errors.Is(err, db.ErrorInvalidRequest) { w.WriteHeader(http.StatusBadRequest) - return + } else { + log.Printf("Can't get order %d: %v", request.OrderID, err) + w.WriteHeader(http.StatusInternalServerError) } - } - - transaction := Transaction{ - MemberNum: num, - Total: -total, - Type: "order", - Date: time.Now(), - OrderPurchase: request.Purchase, - OrderID: &order.ID, - } - httpStatus, err := createTransaction(a.db, &transaction) - if err != nil { - log.Println(err) - w.WriteHeader(httpStatus) return } diff --git a/api/order_test.go b/api/order_test.go index e3a4ee0807357eb2c0f2dfff6f2344f2f847f8e1..af3626159820666a5eacc1d2f6206458bd6b5891 100644 --- a/api/order_test.go +++ b/api/order_test.go @@ -6,13 +6,15 @@ import ( "strconv" "testing" "time" + + "0xacab.org/meskio/cicer/api/db" ) -var testOrder = Order{ +var testOrder = db.Order{ Name: "huevos", Description: "huevos frescos", Deadline: time.Now().Add(24 * time.Hour), - Products: []Product{ + Products: []db.Product{ testProduct, }, } @@ -23,7 +25,7 @@ func TestOrderAddList(t *testing.T) { tapi.addTestMember() tapi.addTestOrder() - var orders []Order + var orders []db.Order resp := tapi.do("GET", "/order", nil, &orders) if resp.StatusCode != http.StatusOK { t.Fatal("Can't get orders:", resp.Status) @@ -49,7 +51,7 @@ func TestOrderActive(t *testing.T) { tapi.addTestMember() tapi.addTestOrder() - var orders []Order + var orders []db.Order resp := tapi.do("GET", "/order/active", nil, &orders) if resp.StatusCode != http.StatusOK { t.Fatal("Can't get orders:", resp.Status) @@ -69,7 +71,7 @@ func TestOrderPurchase(t *testing.T) { tapi.addTestMember() tapi.addTestOrder() - var orders []Order + var orders []db.Order resp := tapi.do("GET", "/order/active", nil, &orders) if resp.StatusCode != http.StatusOK { t.Fatal("Can't get orders:", resp.Status) @@ -77,7 +79,7 @@ func TestOrderPurchase(t *testing.T) { purchase := OrderPurchaseRequest{ OrderID: orders[0].ID, - Purchase: []OrderPurchase{ + Purchase: []db.OrderPurchase{ { ProductCode: testProduct.Code, Amount: 3, @@ -89,7 +91,7 @@ func TestOrderPurchase(t *testing.T) { t.Fatal("Can't create order:", resp.Status) } - var transactions []Transaction + var transactions []db.Transaction resp = tapi.do("GET", "/transaction/mine", nil, &transactions) if resp.StatusCode != http.StatusOK { t.Fatal("Can't get transactions:", resp.Status) @@ -102,7 +104,7 @@ func TestOrderPurchase(t *testing.T) { t.Fatal("Wrong total", transactions[0].Total) } - var member Member + var member db.Member resp = tapi.do("GET", "/member/me", nil, &member) if resp.StatusCode != http.StatusOK { t.Error("Can't find the member:", resp.Status) @@ -125,7 +127,7 @@ func TestOrderNoDeactivation(t *testing.T) { t.Fatal("Can't create order:", resp.Status) } - var orders []Order + var orders []db.Order resp = tapi.do("GET", "/order/active", nil, &orders) if resp.StatusCode != http.StatusOK { t.Fatal("Can't get transactions:", resp.Status) @@ -135,12 +137,14 @@ func TestOrderNoDeactivation(t *testing.T) { } dbPath := path.Join(tapi.testPath, "test.db") - db, err := initDB(dbPath) + database, err := db.Init(dbPath) if err != nil { t.Fatal("Can't initialize the db:", err) } - a := api{db: db, mail: NewMail("", "", "")} - a.deactivateOrders() + orders = database.DeactivateOrders() + if len(orders) != 0 { + t.Error("Deactivated some orders:", orders) + } resp = tapi.do("GET", "/order/active", nil, &orders) if resp.StatusCode != http.StatusOK { @@ -163,7 +167,7 @@ func TestOrderDeactivation(t *testing.T) { t.Fatal("Can't create order:", resp.Status) } - var orders []Order + var orders []db.Order resp = tapi.do("GET", "/order/active", nil, &orders) if resp.StatusCode != http.StatusOK { t.Fatal("Can't get transactions:", resp.Status) @@ -174,7 +178,7 @@ func TestOrderDeactivation(t *testing.T) { purchase := OrderPurchaseRequest{ OrderID: orders[0].ID, - Purchase: []OrderPurchase{ + Purchase: []db.OrderPurchase{ { ProductCode: testProduct.Code, Amount: 3, @@ -187,12 +191,14 @@ func TestOrderDeactivation(t *testing.T) { } dbPath := path.Join(tapi.testPath, "test.db") - db, err := initDB(dbPath) + database, err := db.Init(dbPath) if err != nil { t.Fatal("Can't initialize the db:", err) } - a := api{db: db, mail: NewMail("", "", "")} - a.deactivateOrders() + orders = database.DeactivateOrders() + if len(orders) != 1 { + t.Error("Deactivated wrong orders:", orders) + } resp = tapi.do("GET", "/order/active", nil, &orders) if resp.StatusCode != http.StatusOK { @@ -204,7 +210,7 @@ func TestOrderDeactivation(t *testing.T) { total := 3 * testProduct.Price - var transactions []Transaction + var transactions []db.Transaction resp = tapi.do("GET", "/transaction/mine", nil, &transactions) if resp.StatusCode != http.StatusOK { t.Fatal("Can't get transactions:", resp.Status) @@ -219,7 +225,7 @@ func TestOrderDeactivation(t *testing.T) { t.Fatal("Wrong total:", transactions[0].Total) } - var member Member + var member db.Member resp = tapi.do("GET", "/member/me", nil, &member) if resp.StatusCode != http.StatusOK { t.Fatal("Can't member:", resp.Status) @@ -235,7 +241,7 @@ func TestGetOrder(t *testing.T) { tapi.addTestMember() tapi.addTestOrder() - var orders []Order + var orders []db.Order resp := tapi.do("GET", "/order/active", nil, &orders) if resp.StatusCode != http.StatusOK { t.Fatal("Can't get orders:", resp.Status) @@ -259,7 +265,7 @@ func TestGetOrder(t *testing.T) { purchase := OrderPurchaseRequest{ OrderID: orders[0].ID, - Purchase: []OrderPurchase{ + Purchase: []db.OrderPurchase{ { ProductCode: testProduct.Code, Amount: 3, diff --git a/api/product.go b/api/product.go index b3669b7cb73f815372eed71d710b057cc1f61397..4f679f82f6ec32ed73f5a7ee70726038bd38a664 100644 --- a/api/product.go +++ b/api/product.go @@ -5,28 +5,21 @@ import ( "errors" "log" "net/http" + "strconv" + "0xacab.org/meskio/cicer/api/db" "github.com/gorilla/mux" - "gorm.io/gorm" ) -type Product struct { - gorm.Model `json:"-"` - Code int `json:"code" gorm:"unique;index"` - Name string `json:"name" gorm:"unique;index"` - Price int `json:"price"` - Stock int `json:"stock"` -} - func (a *api) AddProduct(w http.ResponseWriter, req *http.Request) { - var product Product + var product db.Product err := json.NewDecoder(req.Body).Decode(&product) if err != nil { log.Printf("Can't create product: %v", err) w.WriteHeader(http.StatusInternalServerError) return } - err = a.db.Create(&product).Error + err = a.db.AddProduct(&product) if err != nil { log.Printf("Can't create product: %v\n%v", err, product) w.WriteHeader(http.StatusInternalServerError) @@ -44,8 +37,7 @@ func (a *api) AddProduct(w http.ResponseWriter, req *http.Request) { } func (a *api) ListProducts(w http.ResponseWriter, req *http.Request) { - var products []Product - err := a.db.Find(&products).Error + products, err := a.db.ListProducts() if err != nil { log.Printf("Can't list products: %v", err) w.WriteHeader(http.StatusInternalServerError) @@ -63,14 +55,14 @@ func (a *api) ListProducts(w http.ResponseWriter, req *http.Request) { func (a *api) GetProduct(w http.ResponseWriter, req *http.Request) { vars := mux.Vars(req) - var product Product - err := a.db.Where("code = ?", vars["code"]).First(&product).Error + code, _ := strconv.Atoi(vars["code"]) + product, err := a.db.GetProduct(code) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { + if errors.Is(err, db.ErrorNotFound) { w.WriteHeader(http.StatusNotFound) return } - log.Printf("Can't get product %s: %v", vars["code"], err) + log.Printf("Can't get product %d: %v", code, err) w.WriteHeader(http.StatusInternalServerError) return } @@ -86,9 +78,8 @@ func (a *api) GetProduct(w http.ResponseWriter, req *http.Request) { func (a *api) DeleteProduct(w http.ResponseWriter, req *http.Request) { vars := mux.Vars(req) - err := a.db.Unscoped(). - Where("code = ?", vars["code"]). - Delete(&Product{}).Error + code, _ := strconv.Atoi(vars["code"]) + err := a.db.DeleteProduct(code) if err != nil { log.Printf("Can't delete product %s: %v", vars["code"], err) w.WriteHeader(http.StatusInternalServerError) @@ -98,7 +89,7 @@ func (a *api) DeleteProduct(w http.ResponseWriter, req *http.Request) { } func (a *api) UpdateProduct(w http.ResponseWriter, req *http.Request) { - var product Product + var product db.Product err := json.NewDecoder(req.Body).Decode(&product) if err != nil { log.Printf("Can't decode product: %v", err) @@ -107,32 +98,13 @@ func (a *api) UpdateProduct(w http.ResponseWriter, req *http.Request) { } vars := mux.Vars(req) - var dbProduct Product - err = a.db.Where("code = ?", vars["code"]).First(&dbProduct).Error + code, _ := strconv.Atoi(vars["code"]) + dbProduct, err := a.db.UpdateProduct(code, &product) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { + if errors.Is(err, db.ErrorNotFound) { w.WriteHeader(http.StatusNotFound) return } - log.Printf("Can't get product %s: %v", vars["code"], err) - w.WriteHeader(http.StatusInternalServerError) - return - } - - if product.Code != 0 { - dbProduct.Code = product.Code - } - if product.Name != "" { - dbProduct.Name = product.Name - } - if product.Price >= 0 { - dbProduct.Price = product.Price - } - if product.Stock >= 0 { - dbProduct.Stock = product.Stock - } - err = a.db.Save(&dbProduct).Error - if err != nil { log.Printf("Can't update product %s: %v %v", vars["code"], err, product) w.WriteHeader(http.StatusInternalServerError) return diff --git a/api/product_test.go b/api/product_test.go index 024a39d9917fbf36a557e073013c2b7e1cb120db..09ccd9d5196f9215df5c5f3351a6fe36b108d45e 100644 --- a/api/product_test.go +++ b/api/product_test.go @@ -3,9 +3,11 @@ package api import ( "net/http" "testing" + + "0xacab.org/meskio/cicer/api/db" ) -var testProduct = Product{ +var testProduct = db.Product{ Code: 234, Name: "Aceite", Price: 1700, @@ -17,7 +19,7 @@ func TestProductAddList(t *testing.T) { defer tapi.close() tapi.addTestProducts() - var products []Product + var products []db.Product resp := tapi.do("GET", "/product", nil, &products) if resp.StatusCode != http.StatusOK { t.Fatal("Can't get products:", resp.Status) @@ -44,7 +46,7 @@ func TestProductGetDelete(t *testing.T) { } tapi.addTestProducts() - var gotProduct Product + var gotProduct db.Product resp = tapi.do("GET", "/product/234", nil, &gotProduct) if resp.StatusCode != http.StatusOK { t.Error("Can't find the product:", resp.Status) @@ -75,7 +77,7 @@ func TestProductUpdate(t *testing.T) { t.Fatal("Can't update product:", resp.Status) } - var gotProduct Product + var gotProduct db.Product resp = tapi.do("GET", "/product/234", nil, &gotProduct) if resp.StatusCode != http.StatusOK { t.Error("Can't find the product:", resp.Status) diff --git a/api/purchase.go b/api/purchase.go index 235e85e453d242b3298bfdf826d0f656d41e3102..bf8cc9c7039154caee1d5058877f7ded0767ddce 100644 --- a/api/purchase.go +++ b/api/purchase.go @@ -2,78 +2,32 @@ package api import ( "encoding/json" - "fmt" + "errors" "log" "net/http" - "time" - "gorm.io/gorm" + "0xacab.org/meskio/cicer/api/db" ) -type Purchase struct { - gorm.Model `json:"-"` - TransactionID uint `json:"-" gorm:"column:transaction"` - ProductCode int `json:"code" gorm:"column:product"` - Product Product `json:"product" gorm:"foreignKey:ProductCode;references:Code"` - Price int `json:"price"` - Amount int `json:"amount"` -} - func (a *api) AddPurchase(num int, w http.ResponseWriter, req *http.Request) { - var purchase []Purchase + var purchase []db.Purchase err := json.NewDecoder(req.Body).Decode(&purchase) if err != nil { log.Printf("Can't create purchase: %v", err) w.WriteHeader(http.StatusInternalServerError) return } - total := 0 - for i, p := range purchase { - var product Product - err = a.db.Where("code = ?", p.ProductCode).First(&product).Error - if err != nil { - log.Printf("Can't get product %d: %v", p.ProductCode, err) - w.WriteHeader(http.StatusNotAcceptable) - return - } - - total += product.Price * p.Amount - purchase[i].Price = product.Price - } - if total == 0 { - log.Printf("Empty purchase (%d)", num) - w.WriteHeader(http.StatusNotAcceptable) - return - } - - transaction := Transaction{ - MemberNum: num, - Date: time.Now(), - Purchase: purchase, - Type: "purchase", - Total: -total, - } - var httpStatus int - err = a.db.Transaction(func(tx *gorm.DB) error { - httpStatus, err = createTransaction(tx, &transaction) - if err != nil { - return err - } - for _, p := range purchase { - err := tx.Model(&Product{}). - Where("code = ?", p.ProductCode). - Update("stock", gorm.Expr("stock - ?", p.Amount)).Error - if err != nil { - httpStatus = http.StatusInternalServerError - return fmt.Errorf("Can't update product stock %d-%d: %v", p.ProductCode, p.Amount, err) - } - } - return nil - }) + transaction, err := a.db.AddPurchase(num, purchase) if err != nil { - log.Println(err) - w.WriteHeader(httpStatus) + if errors.Is(err, db.ErrorNotFound) { + w.WriteHeader(http.StatusNotAcceptable) + } else if errors.Is(err, db.ErrorInvalidRequest) { + w.WriteHeader(http.StatusBadRequest) + } else { + w.WriteHeader(http.StatusInternalServerError) + log.Printf("Can't create purchase: %v\n%v", err, transaction) + } return } diff --git a/api/purchase_test.go b/api/purchase_test.go index 0ba1285d0e39a30dfb68d126d2fe40dad4d58bd3..53a3cdb440f72d9ec0799427d4045c098a553291 100644 --- a/api/purchase_test.go +++ b/api/purchase_test.go @@ -3,6 +3,8 @@ package api import ( "net/http" "testing" + + "0xacab.org/meskio/cicer/api/db" ) func TestPurchaseAddListMine(t *testing.T) { @@ -11,7 +13,7 @@ func TestPurchaseAddListMine(t *testing.T) { tapi.addTestMember() tapi.addTestProducts() - products := []Purchase{ + products := []db.Purchase{ { ProductCode: testProduct.Code, Amount: 5, @@ -21,7 +23,7 @@ func TestPurchaseAddListMine(t *testing.T) { if resp.StatusCode != http.StatusCreated { t.Fatal("Can't create purchase:", resp.Status) } - var transactions []Transaction + var transactions []db.Transaction resp = tapi.do("GET", "/transaction/mine", nil, &transactions) if resp.StatusCode != http.StatusOK { t.Fatal("Can't get transactions:", resp.Status) @@ -43,7 +45,7 @@ func TestPurchaseAddListMine(t *testing.T) { t.Error("Wrong product price:", transactions[0].Purchase[0].Price) } - var product Product + var product db.Product resp = tapi.do("GET", "/product/234", nil, &product) if resp.StatusCode != http.StatusOK { t.Error("Can't find the product:", resp.Status) @@ -52,7 +54,7 @@ func TestPurchaseAddListMine(t *testing.T) { t.Error("Wrong product stock:", product) } - var member Member + var member db.Member resp = tapi.do("GET", "/member/me", nil, &member) if resp.StatusCode != http.StatusOK { t.Error("Can't find the member:", resp.Status) diff --git a/api/topup.go b/api/topup.go index 7db5cb4197532784704021cd915d0de72488cd9b..513afacc21ceffaa549c7e80ef2e5b84510aa55a 100644 --- a/api/topup.go +++ b/api/topup.go @@ -2,27 +2,21 @@ package api import ( "encoding/json" + "errors" "log" "net/http" - "time" - "gorm.io/gorm" + "0xacab.org/meskio/cicer/api/db" ) -type Topup struct { - gorm.Model `json:"-"` - TransactionID uint `json:"-" gorm:"column:transaction"` - MemberNum int `json:"member" gorm:"column:member"` - Member Member `json:"-" gorm:"foreignKey:MemberNum;references:Num"` - Comment string `json:"comment"` +type TopupReq struct { + Member int `json:"member"` + Comment string `json:"comment"` + Amount int `json:"amount"` } func (a *api) AddTopup(adminNum int, w http.ResponseWriter, req *http.Request) { - var topup struct { - Member int `json:"member"` - Comment string `json:"comment"` - Amount int `json:"amount"` - } + var topup TopupReq err := json.NewDecoder(req.Body).Decode(&topup) if err != nil { log.Printf("Can't parse topup: %v", err) @@ -30,20 +24,16 @@ func (a *api) AddTopup(adminNum int, w http.ResponseWriter, req *http.Request) { return } - transaction := Transaction{ - MemberNum: topup.Member, - Date: time.Now(), - Topup: &Topup{ - MemberNum: adminNum, - Comment: topup.Comment, - }, - Type: "topup", - Total: topup.Amount, - } - httpStatus, err := createTransaction(a.db, &transaction) + transaction, err := a.db.AddTopup(adminNum, topup.Member, topup.Amount, topup.Comment) if err != nil { - log.Println(err) - w.WriteHeader(httpStatus) + if errors.Is(err, db.ErrorNotFound) { + w.WriteHeader(http.StatusNotAcceptable) + } else if errors.Is(err, db.ErrorInvalidRequest) { + w.WriteHeader(http.StatusBadRequest) + } else { + w.WriteHeader(http.StatusInternalServerError) + log.Printf("Can't create topup: %v\n%v", err, transaction) + } return } diff --git a/api/topup_test.go b/api/topup_test.go index 83949566a6250c28dffafa88f697d404e6a3b8a2..3d2eb2ba51afbfe9a4624ea0b8af237a16598833 100644 --- a/api/topup_test.go +++ b/api/topup_test.go @@ -3,6 +3,8 @@ package api import ( "net/http" "testing" + + "0xacab.org/meskio/cicer/api/db" ) func TestTopupAddListMine(t *testing.T) { @@ -20,7 +22,7 @@ func TestTopupAddListMine(t *testing.T) { if resp.StatusCode != http.StatusCreated { t.Fatal("Can't create topup:", resp.Status) } - var transactions []Transaction + var transactions []db.Transaction resp = tapi.do("GET", "/transaction/mine", nil, &transactions) if resp.StatusCode != http.StatusOK { t.Fatal("Can't get transactions:", resp.Status) @@ -39,7 +41,7 @@ func TestTopupAddListMine(t *testing.T) { t.Error("Wrong topup comment:", transactions[0].Topup.Comment) } - var member Member + var member db.Member resp = tapi.do("GET", "/member/me", nil, &member) if resp.StatusCode != http.StatusOK { t.Error("Can't find the member:", resp.Status) diff --git a/api/transaction.go b/api/transaction.go index 6addd9dcba4a66ca3e9de154f7966409fb870efe..bd5d2c461aa80e10c13e10406e823f67e5069206 100644 --- a/api/transaction.go +++ b/api/transaction.go @@ -3,38 +3,16 @@ package api import ( "encoding/json" "errors" - "fmt" "log" "net/http" "strconv" - "time" + "0xacab.org/meskio/cicer/api/db" "github.com/gorilla/mux" - "gorm.io/gorm" - "gorm.io/gorm/clause" ) -type Transaction struct { - gorm.Model - MemberNum int `json:"-" gorm:"column:member"` - Member *Member `json:"member,omitempty" gorm:"foreignKey:MemberNum;references:Num"` - Date time.Time `json:"date"` - Total int `json:"total"` - Type string `json:"type"` - - Purchase []Purchase `json:"purchase,omitempty"` - Topup *Topup `json:"topup,omitempty"` - OrderPurchase []OrderPurchase `json:"order_purchase,omitempty" gorm:"foreignKey:TransactionID"` - Order *Order `json:"order,omitempty"` - OrderID *uint `json:"-"` - Refund *Order `json:"refund,omitempty" gorm:"foreignKey:TransactionID"` -} - func (a *api) ListTransactions(w http.ResponseWriter, req *http.Request) { - var transactions []Transaction - err := a.transactionQuery(). - Order("date desc"). - Find(&transactions).Error + transactions, err := a.db.ListTransactions() if err != nil { log.Printf("Can't list transactions: %v", err) w.WriteHeader(http.StatusInternalServerError) @@ -51,11 +29,10 @@ func (a *api) ListTransactions(w http.ResponseWriter, req *http.Request) { func (a *api) GetTransaction(num int, role string, w http.ResponseWriter, req *http.Request) { vars := mux.Vars(req) - var transaction Transaction - err := a.transactionQuery(). - First(&transaction, vars["id"]).Error + id, _ := strconv.Atoi(vars["id"]) + transaction, err := a.db.GetTransaction(id) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { + if errors.Is(err, db.ErrorNotFound) { w.WriteHeader(http.StatusNotFound) return } @@ -86,11 +63,7 @@ func (a *api) GetMemberTransactions(w http.ResponseWriter, req *http.Request) { } func (a *api) getTransactionsByMember(num int, w http.ResponseWriter, req *http.Request) { - var transactions []Transaction - err := a.transactionQuery(). - Where("member = ?", num). - Order("date desc"). - Find(&transactions).Error + transactions, err := a.db.TransactionByMember(num) if err != nil { log.Printf("Can't list transactions: %v", err) w.WriteHeader(http.StatusInternalServerError) @@ -104,41 +77,3 @@ func (a *api) getTransactionsByMember(num int, w http.ResponseWriter, req *http. w.WriteHeader(http.StatusInternalServerError) } } - -func createTransaction(db *gorm.DB, transaction *Transaction) (httpStatus int, err error) { - httpStatus = http.StatusOK - err = db.Transaction(func(tx *gorm.DB) error { - var member Member - err := tx.Where("num = ?", transaction.MemberNum).Find(&member).Error - if err != nil { - httpStatus = http.StatusNotAcceptable - return fmt.Errorf("Can't find member %d: %v", transaction.MemberNum, err) - } - if member.Balance < -transaction.Total { - httpStatus = http.StatusBadRequest - return fmt.Errorf("Member %d don't have enough money (%d-%d)", member.Num, member.Balance, transaction.Total) - } - err = tx.Model(&Member{}). - Where("num = ?", transaction.MemberNum). - Update("balance", gorm.Expr("balance + ?", transaction.Total)).Error - if err != nil { - httpStatus = http.StatusNotAcceptable - fmt.Errorf("Can't update update member balance %d-%d: %v", member.Num, transaction.Total, err) - } - - err = tx.Create(&transaction).Error - if err != nil { - httpStatus = http.StatusInternalServerError - return fmt.Errorf("Can't create transaction: %v\n%v", err, transaction) - } - return nil - }) - return -} - -func (a *api) transactionQuery() *gorm.DB { - return a.db.Preload("Purchase.Product"). - Preload("Order.Products"). - Preload("OrderPurchase.Product"). - Preload(clause.Associations) -}