diff --git a/api/auth.go b/api/auth.go index a905c21336ce34ab71d4769c32fa5243e635b466..4f880373eeeb965b9019760a41ffadbc9436ca56 100644 --- a/api/auth.go +++ b/api/auth.go @@ -277,7 +277,6 @@ func (a *api) authOrderNum(fn func(int, http.ResponseWriter, *http.Request)) fun } role, ok := claims["role"].(string) if !ok || !roleOrder(role) { - log.Println(role) w.WriteHeader(http.StatusUnauthorized) return } diff --git a/api/db/order.go b/api/db/order.go index ece51efc2b879ad2488c1a59a00168b7755bf354..88297cff9385023c32bf29deae37e78ba9a8842f 100644 --- a/api/db/order.go +++ b/api/db/order.go @@ -35,7 +35,7 @@ type OrderPurchase struct { gorm.Model `json:"-"` TransactionID uint `json:"-"` OrderProductID uint `json:"order_product_id"` - OrderProduct *OrderProduct `json:"order_product"` + OrderProduct *OrderProduct `json:"order_product" gorm:"constraint:OnDelete:CASCADE"` Amount int `json:"amount"` } @@ -55,6 +55,7 @@ func (d *DB) ListOrderPicks(num int) (orders []Order, err error) { Table("(?) as orders", d.db.Model(&Order{}).Order("deadline desc")). Group("name").Order("member_selected desc, deadline desc").Limit(15). Preload(clause.Associations). + Preload("Products.Product"). Find(&orders).Error return } @@ -110,32 +111,10 @@ func (d *DB) UpdateOrder(memberNum int, id int, order *Order) error { dbOrder.Deadline = order.Deadline return d.db.Transaction(func(tx *gorm.DB) error { - for _, product := range order.Products { - var err error - dbProduct := findOrderProduct(product.ProductCode, dbOrder.Products) - if dbProduct != nil { - dbProduct.Price = product.Price - err = tx.Save(&dbProduct).Error - } else { - err = tx.Create(&product).Error - } - if err != nil { - return err - } - } - for _, product := range dbOrder.Products { - if findOrderProduct(product.ProductCode, order.Products) == nil { - err = tx.Delete(&product).Error - if err != nil { - return err - } - } - } - totalSum := 0 for i, t := range dbOrder.Transactions { var transaction Transaction - err := tx.Preload("OrderPurchase.OrderProduct").First(&transaction, id).Error + err = tx.Preload("OrderPurchase.OrderProduct").First(&transaction, id).Error if err != nil { return err } @@ -147,17 +126,56 @@ func (d *DB) UpdateOrder(memberNum int, id int, order *Order) error { totalSum += total } - if dbOrder.TransactionID != nil { - updateOrderTransaction(tx, int(*dbOrder.TransactionID), totalSum, &dbOrder) - } - err := tx.Save(&dbOrder).Error + products, err := updateOrderProducts(tx, *order, dbOrder) if err != nil { return err } - return nil + + if dbOrder.TransactionID != nil { + err = updateOrderTransaction(tx, int(*dbOrder.TransactionID), totalSum, &dbOrder) + if err != nil { + return err + } + } + + dbOrder.Products = products + dbOrder.Transactions = []Transaction{} + return tx.Save(&dbOrder).Error }) } +func updateOrderProducts(tx *gorm.DB, order Order, dbOrder Order) (products []OrderProduct, err error) { + for _, product := range order.Products { + dbProduct := findOrderProduct(product.ProductCode, dbOrder.Products) + if dbProduct != nil { + dbProduct.Price = product.Price + products = append(products, *dbProduct) + err = tx.Save(&dbProduct).Error + } else { + product.OrderID = uint(dbOrder.ID) + err = tx.Create(&product).Error + products = append(products, product) + } + if err != nil { + return + } + } + for _, product := range dbOrder.Products { + if findOrderProduct(product.ProductCode, order.Products) == nil { + err = tx.Where("order_product_id = ?", product.ID).Delete(&OrderPurchase{}).Error + if err != nil { + return + } + err = tx.Delete(&product).Error + if err != nil { + return + } + + } + } + return +} + func updateOrderTransaction(tx *gorm.DB, id int, total int, order *Order) error { var transaction Transaction err := tx.First(&transaction, id).Error diff --git a/api/order_test.go b/api/order_test.go index 1130f13e254713ab6bd3d25497e6ae752167eeb2..dee0d96877aca513e14dfd1c761dfbb629868c14 100644 --- a/api/order_test.go +++ b/api/order_test.go @@ -518,6 +518,21 @@ func TestOrderUpdateProduct(t *testing.T) { tapi.t.Fatal("Can't update order:", resp.Status) } + var orderResponse OrderGetResponse + resp = tapi.do("GET", fmt.Sprintf("/order/%d", orders[0].ID), nil, &orderResponse) + if resp.StatusCode != http.StatusOK { + t.Fatal("Can't get order:", resp.Status) + } + if len(orderResponse.Order.Products) != 1 { + t.Fatal("Wrong len of products:", orderResponse.Order.Products) + } + if orderResponse.Order.Products[0].ProductCode != testProduct2.Code { + t.Fatal("Wrong product code:", orderResponse.Order.Products) + } + if orderResponse.Order.Products[0].Price != testProduct2.Price { + t.Fatal("Wrong product price:", orderResponse.Order.Products) + } + var transactions []db.Transaction resp = tapi.do("GET", "/transaction/mine", nil, &transactions) if resp.StatusCode != http.StatusOK { @@ -529,6 +544,9 @@ func TestOrderUpdateProduct(t *testing.T) { if transactions[0].Total != 0 { t.Error("Wrong total", transactions[0].Total) } + if len(transactions[0].OrderPurchase) != 0 { + t.Error("Wrong purchases", transactions[0].OrderPurchase) + } } func TestOrderUpdateReactivate(t *testing.T) {