diff --git a/api/auth.go b/api/auth.go index b6881e740b6c7cb2c2a41e8ae7a947a8227ca860..b91f4c26a377fb3e2a58fd32cfc67a4da4e36b61 100644 --- a/api/auth.go +++ b/api/auth.go @@ -34,8 +34,7 @@ func (a *api) SignIn(w http.ResponseWriter, req *http.Request) { return } - hash := hashPass(c.Password, member.Salt) - if subtle.ConstantTimeCompare(hash, member.PassHash) == 0 { + if !passwordValid(c.Password, member) { log.Printf("Invalid pass for %s", c.Login) w.WriteHeader(http.StatusBadRequest) return @@ -234,6 +233,11 @@ func newHashPass(password string) (hash []byte, salt []byte, err error) { return } +func passwordValid(password string, member Member) bool { + hash := hashPass(password, member.Salt) + return subtle.ConstantTimeCompare(hash, member.PassHash) == 1 +} + func hashPass(password string, salt []byte) []byte { const ( time = 1 diff --git a/api/member.go b/api/member.go index 2a72b31944fa436bef611f86bc5ef70acc6303ec..dcf871a198645570758590f5164364b97ba8f583 100644 --- a/api/member.go +++ b/api/member.go @@ -11,6 +11,10 @@ import ( "gorm.io/gorm" ) +var ( + ErroBadPassword = errors.New("Bad password") +) + type Member struct { gorm.Model `json:"-"` Num int `json:"num" gorm:"unique;index"` @@ -26,7 +30,8 @@ type Member struct { type MemberReq struct { Member - Password string `json:"password"` + OldPassword string `json:"old_password"` + Password string `json:"password"` } func (a *api) AddMember(w http.ResponseWriter, req *http.Request) { @@ -143,7 +148,7 @@ func (a *api) UpdateMember(w http.ResponseWriter, req *http.Request) { return } - m, err := a.updateMember(num, member) + m, err := a.updateMember(num, member, false) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { w.WriteHeader(http.StatusNotFound) @@ -175,14 +180,16 @@ func (a *api) UpdateMemberMe(num int, w http.ResponseWriter, req *http.Request) member.Num = 0 member.Balance = -1 - m, err := a.updateMember(num, member) + m, err := a.updateMember(num, member, true) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { w.WriteHeader(http.StatusNotFound) - return + } else if errors.Is(err, ErroBadPassword) { + w.WriteHeader(http.StatusBadRequest) + } else { + log.Printf("Can't update member %d: %v", num, err) + w.WriteHeader(http.StatusInternalServerError) } - log.Printf("Can't update member %d: %v", num, err) - w.WriteHeader(http.StatusInternalServerError) return } @@ -196,12 +203,15 @@ func (a *api) UpdateMemberMe(num int, w http.ResponseWriter, req *http.Request) } } -func (a *api) updateMember(num int, member MemberReq) (Member, error) { +func (a *api) updateMember(num int, member MemberReq, checkPass bool) (Member, error) { var dbMember Member err := a.db.Where("num = ?", num).First(&dbMember).Error if err != nil { return dbMember, err } + if checkPass && !passwordValid(member.OldPassword, dbMember) { + return dbMember, ErroBadPassword + } if member.Num != 0 { dbMember.Num = member.Num diff --git a/api/member_test.go b/api/member_test.go index 7758fa48f706b978560a15cb2456d0ef0a2ab755..f374a2b6b56db91f759de5b50d6683a2538ab19c 100644 --- a/api/member_test.go +++ b/api/member_test.go @@ -110,7 +110,14 @@ func TestMemberUpdateMe(t *testing.T) { member := testMember member.Password = "foobar" member.Email = "other@example.com" + member.OldPassword = "not my password" resp := tapi.do("PUT", "/member/me", member, nil) + if resp.StatusCode != http.StatusBadRequest { + t.Error("Did accept an invalid password:", resp.Status) + } + + member.OldPassword = testMember.Password + resp = tapi.do("PUT", "/member/me", member, nil) if resp.StatusCode != http.StatusAccepted { t.Fatal("Can't update member:", resp.Status) }