From 10a9aca2a1dc93c6e09954012d84d39e8319dcd2 Mon Sep 17 00:00:00 2001 From: Philipp Heckel Date: Fri, 8 Jul 2022 10:00:04 -0400 Subject: [PATCH] Delete expired attachments based on mod time instead of DB entry to avoid races --- docs/releases.md | 1 + server/file_cache.go | 23 ++++++++++++++- server/file_cache_test.go | 54 ++++++++++++++++++++++++++---------- server/message_cache.go | 21 -------------- server/message_cache_test.go | 4 --- server/server.go | 5 ++-- 6 files changed, 65 insertions(+), 43 deletions(-) diff --git a/docs/releases.md b/docs/releases.md index 98a07c7e..3ae44d05 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -40,6 +40,7 @@ Thank you to [@wunter8](https://github.com/wunter8) for proactively picking up s * `ntfy user` commands don't work with `auth_file` but works with `auth-file` ([#344](https://github.com/binwiederhier/ntfy/issues/344), thanks to [@Histalek](https://github.com/Histalek) for reporting) * Ignore new draft HTTP `Priority` header ([#351](https://github.com/binwiederhier/ntfy/issues/351), thanks to [@ksurl](https://github.com/ksurl) for reporting) +* Delete expired attachments based on mod time instead of DB entry to avoid races (no ticket) **Documentation:** diff --git a/server/file_cache.go b/server/file_cache.go index ad4961cc..88de935d 100644 --- a/server/file_cache.go +++ b/server/file_cache.go @@ -2,16 +2,18 @@ package server import ( "errors" + "fmt" "heckel.io/ntfy/util" "io" "os" "path/filepath" "regexp" "sync" + "time" ) var ( - fileIDRegex = regexp.MustCompile(`^[-_A-Za-z0-9]+$`) + fileIDRegex = regexp.MustCompile(fmt.Sprintf(`^[-_A-Za-z0-9]{%d}$`, messageIDLength)) errInvalidFileID = errors.New("invalid file ID") errFileExists = errors.New("file exists") ) @@ -88,6 +90,25 @@ func (c *fileCache) Remove(ids ...string) error { return nil } +// Expired returns a list of file IDs for expired files +func (c *fileCache) Expired(olderThan time.Time) ([]string, error) { + entries, err := os.ReadDir(c.dir) + if err != nil { + return nil, err + } + var ids []string + for _, e := range entries { + info, err := e.Info() + if err != nil { + continue + } + if info.ModTime().Before(olderThan) && fileIDRegex.MatchString(e.Name()) { + ids = append(ids, e.Name()) + } + } + return ids, nil +} + func (c *fileCache) Size() int64 { c.mu.Lock() defer c.mu.Unlock() diff --git a/server/file_cache_test.go b/server/file_cache_test.go index 36d1d1a3..971cff1d 100644 --- a/server/file_cache_test.go +++ b/server/file_cache_test.go @@ -8,6 +8,7 @@ import ( "os" "strings" "testing" + "time" ) var ( @@ -16,10 +17,10 @@ var ( func TestFileCache_Write_Success(t *testing.T) { dir, c := newTestFileCache(t) - size, err := c.Write("abc", strings.NewReader("normal file"), util.NewFixedLimiter(999)) + size, err := c.Write("abcdefghijkl", strings.NewReader("normal file"), util.NewFixedLimiter(999)) require.Nil(t, err) require.Equal(t, int64(11), size) - require.Equal(t, "normal file", readFile(t, dir+"/abc")) + require.Equal(t, "normal file", readFile(t, dir+"/abcdefghijkl")) require.Equal(t, int64(11), c.Size()) require.Equal(t, int64(10229), c.Remaining()) } @@ -27,18 +28,18 @@ func TestFileCache_Write_Success(t *testing.T) { func TestFileCache_Write_Remove_Success(t *testing.T) { dir, c := newTestFileCache(t) // max = 10k (10240), each = 1k (1024) for i := 0; i < 10; i++ { // 10x999 = 9990 - size, err := c.Write(fmt.Sprintf("abc%d", i), bytes.NewReader(make([]byte, 999))) + size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 999))) require.Nil(t, err) require.Equal(t, int64(999), size) } require.Equal(t, int64(9990), c.Size()) require.Equal(t, int64(250), c.Remaining()) - require.FileExists(t, dir+"/abc1") - require.FileExists(t, dir+"/abc5") + require.FileExists(t, dir+"/abcdefghijk1") + require.FileExists(t, dir+"/abcdefghijk5") - require.Nil(t, c.Remove("abc1", "abc5")) - require.NoFileExists(t, dir+"/abc1") - require.NoFileExists(t, dir+"/abc5") + require.Nil(t, c.Remove("abcdefghijk1", "abcdefghijk5")) + require.NoFileExists(t, dir+"/abcdefghijk1") + require.NoFileExists(t, dir+"/abcdefghijk5") require.Equal(t, int64(7992), c.Size()) require.Equal(t, int64(2248), c.Remaining()) } @@ -46,27 +47,50 @@ func TestFileCache_Write_Remove_Success(t *testing.T) { func TestFileCache_Write_FailedTotalSizeLimit(t *testing.T) { dir, c := newTestFileCache(t) for i := 0; i < 10; i++ { - size, err := c.Write(fmt.Sprintf("abc%d", i), bytes.NewReader(oneKilobyteArray)) + size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(oneKilobyteArray)) require.Nil(t, err) require.Equal(t, int64(1024), size) } - _, err := c.Write("abc11", bytes.NewReader(oneKilobyteArray)) + _, err := c.Write("abcdefghijkX", bytes.NewReader(oneKilobyteArray)) require.Equal(t, util.ErrLimitReached, err) - require.NoFileExists(t, dir+"/abc11") + require.NoFileExists(t, dir+"/abcdefghijkX") } func TestFileCache_Write_FailedFileSizeLimit(t *testing.T) { dir, c := newTestFileCache(t) - _, err := c.Write("abc", bytes.NewReader(make([]byte, 1025))) + _, err := c.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1025))) require.Equal(t, util.ErrLimitReached, err) - require.NoFileExists(t, dir+"/abc") + require.NoFileExists(t, dir+"/abcdefghijkl") } func TestFileCache_Write_FailedAdditionalLimiter(t *testing.T) { dir, c := newTestFileCache(t) - _, err := c.Write("abc", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000)) + _, err := c.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000)) require.Equal(t, util.ErrLimitReached, err) - require.NoFileExists(t, dir+"/abc") + require.NoFileExists(t, dir+"/abcdefghijkl") +} + +func TestFileCache_RemoveExpired(t *testing.T) { + dir, c := newTestFileCache(t) + _, err := c.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001))) + require.Nil(t, err) + _, err = c.Write("notdeleted12", bytes.NewReader(make([]byte, 1001))) + require.Nil(t, err) + + modTime := time.Now().Add(-1 * 4 * time.Hour) + require.Nil(t, os.Chtimes(dir+"/abcdefghijkl", modTime, modTime)) + + olderThan := time.Now().Add(-1 * 3 * time.Hour) + ids, err := c.Expired(olderThan) + require.Nil(t, err) + require.Equal(t, []string{"abcdefghijkl"}, ids) + require.Nil(t, c.Remove(ids...)) + require.NoFileExists(t, dir+"/abcdefghijkl") + require.FileExists(t, dir+"/notdeleted12") + + ids, err = c.Expired(olderThan) + require.Nil(t, err) + require.Empty(t, ids) } func newTestFileCache(t *testing.T) (dir string, cache *fileCache) { diff --git a/server/message_cache.go b/server/message_cache.go index f6fba96d..2e9c577e 100644 --- a/server/message_cache.go +++ b/server/message_cache.go @@ -85,7 +85,6 @@ const ( selectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM messages GROUP BY topic` selectTopicsQuery = `SELECT topic FROM messages GROUP BY topic` selectAttachmentsSizeQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE sender = ? AND attachment_expires >= ?` - selectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires < ?` ) // Schema management queries @@ -409,26 +408,6 @@ func (c *messageCache) AttachmentBytesUsed(sender string) (int64, error) { return size, nil } -func (c *messageCache) AttachmentsExpired() ([]string, error) { - rows, err := c.db.Query(selectAttachmentsExpiredQuery, time.Now().Unix()) - if err != nil { - return nil, err - } - defer rows.Close() - ids := make([]string, 0) - for rows.Next() { - var id string - if err := rows.Scan(&id); err != nil { - return nil, err - } - ids = append(ids, id) - } - if err := rows.Err(); err != nil { - return nil, err - } - return ids, nil -} - func readMessages(rows *sql.Rows) ([]*message, error) { defer rows.Close() messages := make([]*message, 0) diff --git a/server/message_cache_test.go b/server/message_cache_test.go index b68fc330..23c080d4 100644 --- a/server/message_cache_test.go +++ b/server/message_cache_test.go @@ -344,10 +344,6 @@ func testCacheAttachments(t *testing.T, c *messageCache) { size, err = c.AttachmentBytesUsed("5.6.7.8") require.Nil(t, err) require.Equal(t, int64(0), size) - - ids, err := c.AttachmentsExpired() - require.Nil(t, err) - require.Equal(t, []string{"m1"}, ids) } func TestSqliteCache_Migration_From0(t *testing.T) { diff --git a/server/server.go b/server/server.go index ca0d6393..94f35801 100644 --- a/server/server.go +++ b/server/server.go @@ -1116,8 +1116,9 @@ func (s *Server) updateStatsAndPrune() { log.Debug("Manager: Deleted %d stale visitor(s)", staleVisitors) // Delete expired attachments - if s.fileCache != nil { - ids, err := s.messageCache.AttachmentsExpired() + if s.fileCache != nil && s.config.AttachmentExpiryDuration > 0 { + olderThan := time.Now().Add(-1 * s.config.AttachmentExpiryDuration) + ids, err := s.fileCache.Expired(olderThan) if err != nil { log.Warn("Error retrieving expired attachments: %s", err.Error()) } else if len(ids) > 0 {