1
0
Fork 0
mirror of https://github.com/binwiederhier/ntfy.git synced 2024-12-14 11:47:33 +00:00

(Hopefully) remove statsQueue races

This commit is contained in:
binwiederhier 2023-01-27 09:59:16 -05:00
parent 22c66203a0
commit 9e9caee639
4 changed files with 23 additions and 24 deletions

View file

@ -599,7 +599,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
} }
v.IncrementMessages() v.IncrementMessages()
if s.userManager != nil && v.user != nil { if s.userManager != nil && v.user != nil {
s.userManager.EnqueueStats(v.user) // FIXME this makes no sense for tier-less users s.userManager.EnqueueStats(v.user.ID, v.Stats()) // FIXME this makes no sense for tier-less users
} }
s.mu.Lock() s.mu.Lock()
s.messages++ s.messages++

View file

@ -232,17 +232,20 @@ func (v *visitor) IncrementMessages() {
v.mu.Lock() v.mu.Lock()
defer v.mu.Unlock() defer v.mu.Unlock()
v.messages++ v.messages++
if v.user != nil {
v.user.Stats.Messages = v.messages
}
} }
func (v *visitor) IncrementEmails() { func (v *visitor) IncrementEmails() {
v.mu.Lock() v.mu.Lock()
defer v.mu.Unlock() defer v.mu.Unlock()
v.emails++ v.emails++
if v.user != nil { }
v.user.Stats.Emails = v.emails
func (v *visitor) Stats() *user.Stats {
v.mu.Lock()
defer v.mu.Unlock()
return &user.Stats{
Messages: v.messages,
Emails: v.emails,
} }
} }
@ -254,10 +257,6 @@ func (v *visitor) ResetStats() {
if v.messagesLimiter != nil { if v.messagesLimiter != nil {
v.messagesLimiter.Reset() v.messagesLimiter.Reset()
} }
if v.user != nil {
v.user.Stats.Messages = 0
v.user.Stats.Emails = 0
}
} }
// SetUser sets the visitors user to the given value // SetUser sets the visitors user to the given value

View file

@ -292,8 +292,8 @@ const (
// in a SQLite database. // in a SQLite database.
type Manager struct { type Manager struct {
db *sql.DB db *sql.DB
defaultAccess Permission // Default permission if no ACL matches defaultAccess Permission // Default permission if no ACL matches
statsQueue map[string]*User // Username -> User, for "unimportant" user updates statsQueue map[string]*Stats // "Queue" to asynchronously write user stats to the database (UserID -> Stats)
mu sync.Mutex mu sync.Mutex
} }
@ -319,7 +319,7 @@ func newManager(filename, startupQueries string, defaultAccess Permission, stats
manager := &Manager{ manager := &Manager{
db: db, db: db,
defaultAccess: defaultAccess, defaultAccess: defaultAccess,
statsQueue: make(map[string]*User), statsQueue: make(map[string]*Stats),
} }
go manager.userStatsQueueWriter(statsWriterInterval) go manager.userStatsQueueWriter(statsWriterInterval)
return manager, nil return manager, nil
@ -464,16 +464,16 @@ func (a *Manager) ResetStats() error {
if _, err := a.db.Exec(updateUserStatsResetAllQuery); err != nil { if _, err := a.db.Exec(updateUserStatsResetAllQuery); err != nil {
return err return err
} }
a.statsQueue = make(map[string]*User) a.statsQueue = make(map[string]*Stats)
return nil return nil
} }
// EnqueueStats adds the user to a queue which writes out user stats (messages, emails, ..) in // EnqueueStats adds the user to a queue which writes out user stats (messages, emails, ..) in
// batches at a regular interval // batches at a regular interval
func (a *Manager) EnqueueStats(user *User) { func (a *Manager) EnqueueStats(userID string, stats *Stats) {
a.mu.Lock() a.mu.Lock()
defer a.mu.Unlock() defer a.mu.Unlock()
a.statsQueue[user.ID] = user a.statsQueue[userID] = stats
} }
func (a *Manager) userStatsQueueWriter(interval time.Duration) { func (a *Manager) userStatsQueueWriter(interval time.Duration) {
@ -493,7 +493,7 @@ func (a *Manager) writeUserStatsQueue() error {
return nil return nil
} }
statsQueue := a.statsQueue statsQueue := a.statsQueue
a.statsQueue = make(map[string]*User) a.statsQueue = make(map[string]*Stats)
a.mu.Unlock() a.mu.Unlock()
tx, err := a.db.Begin() tx, err := a.db.Begin()
if err != nil { if err != nil {
@ -501,9 +501,9 @@ func (a *Manager) writeUserStatsQueue() error {
} }
defer tx.Rollback() defer tx.Rollback()
log.Debug("User Manager: Writing user stats queue for %d user(s)", len(statsQueue)) log.Debug("User Manager: Writing user stats queue for %d user(s)", len(statsQueue))
for userID, u := range statsQueue { for userID, update := range statsQueue {
log.Trace("User Manager: Updating stats for user %s: messages=%d, emails=%d", userID, u.Stats.Messages, u.Stats.Emails) log.Trace("User Manager: Updating stats for user %s: messages=%d, emails=%d", userID, update.Messages, update.Emails)
if _, err := tx.Exec(updateUserStatsQuery, u.Stats.Messages, u.Stats.Emails, userID); err != nil { if _, err := tx.Exec(updateUserStatsQuery, update.Messages, update.Emails, userID); err != nil {
return err return err
} }
} }

View file

@ -554,10 +554,10 @@ func TestManager_EnqueueStats(t *testing.T) {
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, int64(0), u.Stats.Messages) require.Equal(t, int64(0), u.Stats.Messages)
require.Equal(t, int64(0), u.Stats.Emails) require.Equal(t, int64(0), u.Stats.Emails)
a.EnqueueStats(u.ID, &Stats{
u.Stats.Messages = 11 Messages: 11,
u.Stats.Emails = 2 Emails: 2,
a.EnqueueStats(u) })
// Still no change, because it's queued asynchronously // Still no change, because it's queued asynchronously
u, err = a.User("ben") u, err = a.User("ben")