Newer
Older
package db
import (
"crypto/rand"
"crypto/subtle"
"encoding/base64"
"errors"
"log"
"time"
"golang.org/x/crypto/argon2"
"gorm.io/gorm"
)
const (
timeExpireResetToken = 7 * 24 * time.Hour
CreatedAt time.Time `json:"-"`
UpdatedAt time.Time `json:"-"`
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
Num int `json:"num" gorm:"primaryKey"`
Login *string `json:"login" gorm:"unique;index"`
Name string `json:"name"`
Email string `json:"email"`
Phone string `json:"phone"`
Balance int `json:"balance"`
Role string `json:"role"`
PassHash []byte `json:"-"`
Salt []byte `json:"-"`
}
// PasswordReset is a single-use reset token created by NewPasswordReset for
// the member identified by MemberNum; ResetPassword deletes it once used.
type PasswordReset struct {
gorm.Model
// Token is the random URL-safe token string; unique and indexed for lookup.
Token string `gorm:"unique;index"`
// MemberNum references Member.Num and is stored in the "member" column.
MemberNum int `gorm:"column:member"`
// Member is the associated member, populated by GetPasswordReset's Preload.
Member *Member `gorm:"foreignKey:MemberNum;references:Num"`
}
// MemberReq is the request payload for creating or updating a member: the
// Member fields (PassHash/Salt are excluded from JSON) plus plaintext
// password fields that are hashed, never stored directly.
type MemberReq struct {
Member
// OldPassword is verified against the stored hash by UpdateMember when its
// checkPass argument is true.
OldPassword string `json:"old_password"`
// Password is the new plaintext password. AddMember always hashes it;
// UpdateMember hashes it only when non-empty.
Password string `json:"password"`
}
// AddMember builds a new Member from memberReq, hashes memberReq.Password
// with a fresh random salt (newHashPass) and inserts the row. It returns the
// stored member, or the hashing/insert error.
//
// Login is copied (after passing through cleanLogin) only when
// memberReq.Login is non-nil. NOTE(review): the password is hashed even when
// it is empty — confirm callers validate it first.
func (d DB) AddMember(memberReq *MemberReq) (member Member, err error) {
member.Num = memberReq.Num
if memberReq.Login != nil {
member.Login = cleanLogin(*memberReq.Login)
}
// NOTE(review): the bare numbers below are residue from a diff view, not Go
// code. The statements that stood here (presumably copying Name and Email)
// are missing and must be restored before this file compiles.
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
member.Phone = memberReq.Phone
member.Balance = memberReq.Balance
member.Role = memberReq.Role
member.PassHash, member.Salt, err = newHashPass(memberReq.Password)
if err != nil {
return
}
err = d.db.Create(&member).Error
return
}
// ListMembers loads every member row and returns it together with any query
// error. Soft-deleted members are excluded by GORM's default scope, because
// Member carries a gorm.DeletedAt field.
func (d DB) ListMembers() ([]Member, error) {
	var all []Member
	result := d.db.Find(&all)
	return all, result.Error
}
// GetMember fetches the member whose primary key Num equals num.
// A missing row is reported as ErrorNotFound; any other database error is
// passed through unchanged.
func (d DB) GetMember(num int) (Member, error) {
	var found Member
	if err := d.db.Where("num = ?", num).First(&found).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return found, ErrorNotFound
		}
		return found, err
	}
	return found, nil
}
// DeleteMember removes the member with the given num. Because Member has a
// gorm.DeletedAt field, GORM performs a soft delete (sets deleted_at).
// Only the database error, if any, is returned.
func (d DB) DeleteMember(num int) error {
	query := d.db.Where("num = ?", num)
	return query.Delete(&Member{}).Error
}
// UpdateMember applies a partial update to the member identified by num.
// When checkPass is true, member.OldPassword must match the stored password
// or ErrorBadPassword is returned. A non-zero Num and non-empty Name, Phone
// and Role overwrite the stored values; a non-empty Password is rehashed
// with a new salt. The result is written back with Save and returned.
// It returns ErrorNotFound when no member has the given num.
//
// NOTE(review): changing Num changes the primary key before Save. GORM's
// Save keys its UPDATE on the primary key, so this likely inserts a second
// row instead of renumbering the existing one — confirm.
func (d DB) UpdateMember(num int, member MemberReq, checkPass bool) (Member, error) {
var dbMember Member
err := d.db.Where("num = ?", num).First(&dbMember).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
err = ErrorNotFound
}
return dbMember, err
}
if checkPass && !passwordValid(member.OldPassword, dbMember) {
return dbMember, ErrorBadPassword
}
if member.Num != 0 {
dbMember.Num = member.Num
}
// NOTE(review): the stray `}` below and the body-less Email branch further
// down are diff residue — the statements that updated Login/Email are
// missing, and the braces no longer balance.
}
if member.Name != "" {
dbMember.Name = member.Name
}
if member.Email != "" {
if member.Phone != "" {
dbMember.Phone = member.Phone
}
if member.Role != "" {
dbMember.Role = member.Role
}
if member.Password != "" {
dbMember.PassHash, dbMember.Salt, err = newHashPass(member.Password)
if err != nil {
return dbMember, err
}
}
err = d.db.Save(&dbMember).Error
return dbMember, err
}
// Login authenticates a member: it looks the member up using the
// cleanLogin-processed login and verifies password against the stored hash
// (passwordValid). Errors from the lookup are returned unchanged. A wrong
// password yields ErrorBadPassword, and the looked-up member is still
// returned in that case.
//
// NOTE(review): this block is garbled diff output. The two consecutive
// `err = d.db.Where(...)` lines are alternative versions of the same
// statement (lookup by email vs. by login). The bare numbers and the stray
// `}` stand in for statements not visible here. Restore from version control.
func (d DB) Login(login, password string) (member Member, err error) {
cleanedLogin := cleanLogin(login)
err = d.db.Where("email = ?", cleanedLogin).
err = d.db.Where("login = ?", cleanedLogin).
First(&member).Error
if err != nil {
return
}
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
}
if !passwordValid(password, member) {
err = ErrorBadPassword
}
return
}
// NewPasswordReset issues a password-reset token for the member whose Email
// equals email and stores it as a PasswordReset row. On success it returns
// the member, the URL-safe token (15 random bytes, base64url-encoded) and a
// nil error.
//
// It returns ErrorNotFound only when no member has that email. Any other
// lookup error (e.g. a database failure) is returned unchanged instead of
// being disguised as "not found". On any failure the returned token is
// empty, so callers cannot hand out a token that was never stored.
func (d DB) NewPasswordReset(email string) (member Member, token string, err error) {
	err = d.db.Where("email = ?", email).First(&member).Error
	if err != nil {
		log.Printf("Can't locate user %s: %v", email, err)
		if errors.Is(err, gorm.ErrRecordNotFound) {
			err = ErrorNotFound
		}
		return member, "", err
	}

	// 15 bytes is a multiple of 3, so the base64 output has no '=' padding.
	tokenBytes := make([]byte, 15)
	if _, err = rand.Read(tokenBytes); err != nil {
		log.Printf("Can't generate a random token for password reset: %v", err)
		return member, "", err
	}
	token = base64.URLEncoding.EncodeToString(tokenBytes)

	passwordReset := PasswordReset{
		Token:     token,
		MemberNum: member.Num,
	}
	if err = d.db.Create(&passwordReset).Error; err != nil {
		return member, "", err
	}
	return member, token, nil
}
// ResetPassword consumes a password-reset token. It looks the token up,
// stores a freshly salted hash of password on the member the token belongs
// to, and deletes the token. The member update and the token deletion run in
// one transaction, so either both happen or neither does.
//
// It returns the lookup error from GetPasswordReset (ErrorNotFound for an
// unknown token). It also returns ErrorNotFound when the referenced member
// no longer exists.
//
// NOTE(review): login is not used; its intended effect (e.g. assigning a
// login during reset) is not visible here — confirm with callers.
func (d *DB) ResetPassword(token, password, login string) error {
	passwordReset, err := d.GetPasswordReset(token)
	if err != nil {
		return err
	}

	// Hash before opening the transaction so it is not held open while the
	// (memory-hard) Argon2 computation runs.
	passHash, salt, err := newHashPass(password)
	if err != nil {
		return err
	}

	return d.db.Transaction(func(tx *gorm.DB) error {
		res := tx.Model(&Member{}).
			Where("num = ?", passwordReset.MemberNum).
			Updates(Member{PassHash: passHash, Salt: salt})
		if res.Error != nil {
			return res.Error
		}
		// A fresh random salt always changes the row, so zero affected rows
		// means the member is gone (or soft-deleted).
		if res.RowsAffected == 0 {
			return ErrorNotFound
		}
		return tx.Delete(&passwordReset).Error
	})
}
// GetPasswordReset returns the PasswordReset row with the given token, with
// its Member association preloaded. It returns ErrorNotFound when no row
// matches; other errors are returned unchanged.
//
// NOTE(review): the visible query does not check the token's age, so an
// expired token stays usable until CleanPasswordReset removes it. Confirm
// whether the missing lines below performed such a check.
func (d *DB) GetPasswordReset(token string) (passwordReset PasswordReset, err error) {
// NOTE(review): the bare numbers below are diff-view residue, not Go code;
// whatever statements stood here are missing.
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
err = d.db.Where("token = ?", token).
Preload("Member").
First(&passwordReset).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
err = ErrorNotFound
}
return
}
// CleanPasswordReset deletes every password-reset token created more than
// timeExpireResetToken ago and logs how many were removed. Errors are logged
// rather than returned.
func (d *DB) CleanPasswordReset() {
	// Tokens created before this instant have expired. The duration must be
	// subtracted: adding it puts the cutoff in the future and deletes every
	// token, including ones issued moments ago.
	cutoff := time.Now().Add(-timeExpireResetToken)
	// Exactly one argument for the single placeholder.
	res := d.db.Where("created_at < ?", cutoff).Delete(&PasswordReset{})
	if res.Error != nil {
		log.Println("Error deleting old reset tokens:", res.Error)
		return
	}
	if res.RowsAffected != 0 {
		log.Println("Deleted", res.RowsAffected, "password reset tokens")
	}
}
// newHashPass draws a 16-byte random salt and returns the hash of password
// under that salt (via hashPass), followed by the salt itself. If the random
// source fails, the hash is nil and the error is returned.
func newHashPass(password string) (hash []byte, salt []byte, err error) {
	salt = make([]byte, 16)
	if _, err = rand.Read(salt); err != nil {
		return nil, salt, err
	}
	return hashPass(password, salt), salt, nil
}
// passwordValid reports whether password, hashed with the member's stored
// salt, reproduces the member's stored hash. The comparison takes time that
// depends only on the slice lengths, not their contents.
func passwordValid(password string, member Member) bool {
	candidate := hashPass(password, member.Salt)
	match := subtle.ConstantTimeCompare(member.PassHash, candidate)
	return match == 1
}
// hashPass derives a 32-byte Argon2id key from password and salt using a
// fixed cost: 1 pass, 64 MiB of memory and 2 threads.
func hashPass(password string, salt []byte) []byte {
	const (
		passes    uint32 = 1
		memoryKiB uint32 = 64 * 1024 // argon2 takes memory in KiB
		lanes     uint8  = 2
		hashLen   uint32 = 32
	)
	secret := []byte(password)
	return argon2.IDKey(secret, salt, passes, memoryKiB, lanes, hashLen)
}