Delete expired attachments based on mod time instead of DB entry to avoid races

This commit is contained in:
Philipp Heckel 2022-07-08 10:00:04 -04:00
parent 3e53d8a2c7
commit 10a9aca2a1
6 changed files with 65 additions and 43 deletions

View File

@ -40,6 +40,7 @@ Thank you to [@wunter8](https://github.com/wunter8) for proactively picking up s
* `ntfy user` commands don't work with `auth_file` but works with `auth-file` ([#344](https://github.com/binwiederhier/ntfy/issues/344), thanks to [@Histalek](https://github.com/Histalek) for reporting) * `ntfy user` commands don't work with `auth_file` but works with `auth-file` ([#344](https://github.com/binwiederhier/ntfy/issues/344), thanks to [@Histalek](https://github.com/Histalek) for reporting)
* Ignore new draft HTTP `Priority` header ([#351](https://github.com/binwiederhier/ntfy/issues/351), thanks to [@ksurl](https://github.com/ksurl) for reporting) * Ignore new draft HTTP `Priority` header ([#351](https://github.com/binwiederhier/ntfy/issues/351), thanks to [@ksurl](https://github.com/ksurl) for reporting)
* Delete expired attachments based on mod time instead of DB entry to avoid races (no ticket)
**Documentation:** **Documentation:**

View File

@ -2,16 +2,18 @@ package server
import ( import (
"errors" "errors"
"fmt"
"heckel.io/ntfy/util" "heckel.io/ntfy/util"
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
"regexp" "regexp"
"sync" "sync"
"time"
) )
var ( var (
fileIDRegex = regexp.MustCompile(`^[-_A-Za-z0-9]+$`) fileIDRegex = regexp.MustCompile(fmt.Sprintf(`^[-_A-Za-z0-9]{%d}$`, messageIDLength))
errInvalidFileID = errors.New("invalid file ID") errInvalidFileID = errors.New("invalid file ID")
errFileExists = errors.New("file exists") errFileExists = errors.New("file exists")
) )
@ -88,6 +90,25 @@ func (c *fileCache) Remove(ids ...string) error {
return nil return nil
} }
// Expired returns a list of file IDs for expired files
func (c *fileCache) Expired(olderThan time.Time) ([]string, error) {
entries, err := os.ReadDir(c.dir)
if err != nil {
return nil, err
}
var ids []string
for _, e := range entries {
info, err := e.Info()
if err != nil {
continue
}
if info.ModTime().Before(olderThan) && fileIDRegex.MatchString(e.Name()) {
ids = append(ids, e.Name())
}
}
return ids, nil
}
func (c *fileCache) Size() int64 { func (c *fileCache) Size() int64 {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()

View File

@ -8,6 +8,7 @@ import (
"os" "os"
"strings" "strings"
"testing" "testing"
"time"
) )
var ( var (
@ -16,10 +17,10 @@ var (
func TestFileCache_Write_Success(t *testing.T) { func TestFileCache_Write_Success(t *testing.T) {
dir, c := newTestFileCache(t) dir, c := newTestFileCache(t)
size, err := c.Write("abc", strings.NewReader("normal file"), util.NewFixedLimiter(999)) size, err := c.Write("abcdefghijkl", strings.NewReader("normal file"), util.NewFixedLimiter(999))
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, int64(11), size) require.Equal(t, int64(11), size)
require.Equal(t, "normal file", readFile(t, dir+"/abc")) require.Equal(t, "normal file", readFile(t, dir+"/abcdefghijkl"))
require.Equal(t, int64(11), c.Size()) require.Equal(t, int64(11), c.Size())
require.Equal(t, int64(10229), c.Remaining()) require.Equal(t, int64(10229), c.Remaining())
} }
@ -27,18 +28,18 @@ func TestFileCache_Write_Success(t *testing.T) {
func TestFileCache_Write_Remove_Success(t *testing.T) { func TestFileCache_Write_Remove_Success(t *testing.T) {
dir, c := newTestFileCache(t) // max = 10k (10240), each = 1k (1024) dir, c := newTestFileCache(t) // max = 10k (10240), each = 1k (1024)
for i := 0; i < 10; i++ { // 10x999 = 9990 for i := 0; i < 10; i++ { // 10x999 = 9990
size, err := c.Write(fmt.Sprintf("abc%d", i), bytes.NewReader(make([]byte, 999))) size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 999)))
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, int64(999), size) require.Equal(t, int64(999), size)
} }
require.Equal(t, int64(9990), c.Size()) require.Equal(t, int64(9990), c.Size())
require.Equal(t, int64(250), c.Remaining()) require.Equal(t, int64(250), c.Remaining())
require.FileExists(t, dir+"/abc1") require.FileExists(t, dir+"/abcdefghijk1")
require.FileExists(t, dir+"/abc5") require.FileExists(t, dir+"/abcdefghijk5")
require.Nil(t, c.Remove("abc1", "abc5")) require.Nil(t, c.Remove("abcdefghijk1", "abcdefghijk5"))
require.NoFileExists(t, dir+"/abc1") require.NoFileExists(t, dir+"/abcdefghijk1")
require.NoFileExists(t, dir+"/abc5") require.NoFileExists(t, dir+"/abcdefghijk5")
require.Equal(t, int64(7992), c.Size()) require.Equal(t, int64(7992), c.Size())
require.Equal(t, int64(2248), c.Remaining()) require.Equal(t, int64(2248), c.Remaining())
} }
@ -46,27 +47,50 @@ func TestFileCache_Write_Remove_Success(t *testing.T) {
func TestFileCache_Write_FailedTotalSizeLimit(t *testing.T) { func TestFileCache_Write_FailedTotalSizeLimit(t *testing.T) {
dir, c := newTestFileCache(t) dir, c := newTestFileCache(t)
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
size, err := c.Write(fmt.Sprintf("abc%d", i), bytes.NewReader(oneKilobyteArray)) size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(oneKilobyteArray))
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, int64(1024), size) require.Equal(t, int64(1024), size)
} }
_, err := c.Write("abc11", bytes.NewReader(oneKilobyteArray)) _, err := c.Write("abcdefghijkX", bytes.NewReader(oneKilobyteArray))
require.Equal(t, util.ErrLimitReached, err) require.Equal(t, util.ErrLimitReached, err)
require.NoFileExists(t, dir+"/abc11") require.NoFileExists(t, dir+"/abcdefghijkX")
} }
func TestFileCache_Write_FailedFileSizeLimit(t *testing.T) { func TestFileCache_Write_FailedFileSizeLimit(t *testing.T) {
dir, c := newTestFileCache(t) dir, c := newTestFileCache(t)
_, err := c.Write("abc", bytes.NewReader(make([]byte, 1025))) _, err := c.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1025)))
require.Equal(t, util.ErrLimitReached, err) require.Equal(t, util.ErrLimitReached, err)
require.NoFileExists(t, dir+"/abc") require.NoFileExists(t, dir+"/abcdefghijkl")
} }
func TestFileCache_Write_FailedAdditionalLimiter(t *testing.T) { func TestFileCache_Write_FailedAdditionalLimiter(t *testing.T) {
dir, c := newTestFileCache(t) dir, c := newTestFileCache(t)
_, err := c.Write("abc", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000)) _, err := c.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000))
require.Equal(t, util.ErrLimitReached, err) require.Equal(t, util.ErrLimitReached, err)
require.NoFileExists(t, dir+"/abc") require.NoFileExists(t, dir+"/abcdefghijkl")
}
func TestFileCache_RemoveExpired(t *testing.T) {
dir, c := newTestFileCache(t)
_, err := c.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001)))
require.Nil(t, err)
_, err = c.Write("notdeleted12", bytes.NewReader(make([]byte, 1001)))
require.Nil(t, err)
modTime := time.Now().Add(-1 * 4 * time.Hour)
require.Nil(t, os.Chtimes(dir+"/abcdefghijkl", modTime, modTime))
olderThan := time.Now().Add(-1 * 3 * time.Hour)
ids, err := c.Expired(olderThan)
require.Nil(t, err)
require.Equal(t, []string{"abcdefghijkl"}, ids)
require.Nil(t, c.Remove(ids...))
require.NoFileExists(t, dir+"/abcdefghijkl")
require.FileExists(t, dir+"/notdeleted12")
ids, err = c.Expired(olderThan)
require.Nil(t, err)
require.Empty(t, ids)
} }
func newTestFileCache(t *testing.T) (dir string, cache *fileCache) { func newTestFileCache(t *testing.T) (dir string, cache *fileCache) {

View File

@ -85,7 +85,6 @@ const (
selectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM messages GROUP BY topic` selectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM messages GROUP BY topic`
selectTopicsQuery = `SELECT topic FROM messages GROUP BY topic` selectTopicsQuery = `SELECT topic FROM messages GROUP BY topic`
selectAttachmentsSizeQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE sender = ? AND attachment_expires >= ?` selectAttachmentsSizeQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE sender = ? AND attachment_expires >= ?`
selectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires < ?`
) )
// Schema management queries // Schema management queries
@ -409,26 +408,6 @@ func (c *messageCache) AttachmentBytesUsed(sender string) (int64, error) {
return size, nil return size, nil
} }
func (c *messageCache) AttachmentsExpired() ([]string, error) {
rows, err := c.db.Query(selectAttachmentsExpiredQuery, time.Now().Unix())
if err != nil {
return nil, err
}
defer rows.Close()
ids := make([]string, 0)
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
return nil, err
}
ids = append(ids, id)
}
if err := rows.Err(); err != nil {
return nil, err
}
return ids, nil
}
func readMessages(rows *sql.Rows) ([]*message, error) { func readMessages(rows *sql.Rows) ([]*message, error) {
defer rows.Close() defer rows.Close()
messages := make([]*message, 0) messages := make([]*message, 0)

View File

@ -344,10 +344,6 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
size, err = c.AttachmentBytesUsed("5.6.7.8") size, err = c.AttachmentBytesUsed("5.6.7.8")
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, int64(0), size) require.Equal(t, int64(0), size)
ids, err := c.AttachmentsExpired()
require.Nil(t, err)
require.Equal(t, []string{"m1"}, ids)
} }
func TestSqliteCache_Migration_From0(t *testing.T) { func TestSqliteCache_Migration_From0(t *testing.T) {

View File

@ -1116,8 +1116,9 @@ func (s *Server) updateStatsAndPrune() {
log.Debug("Manager: Deleted %d stale visitor(s)", staleVisitors) log.Debug("Manager: Deleted %d stale visitor(s)", staleVisitors)
// Delete expired attachments // Delete expired attachments
if s.fileCache != nil { if s.fileCache != nil && s.config.AttachmentExpiryDuration > 0 {
ids, err := s.messageCache.AttachmentsExpired() olderThan := time.Now().Add(-1 * s.config.AttachmentExpiryDuration)
ids, err := s.fileCache.Expired(olderThan)
if err != nil { if err != nil {
log.Warn("Error retrieving expired attachments: %s", err.Error()) log.Warn("Error retrieving expired attachments: %s", err.Error())
} else if len(ids) > 0 { } else if len(ids) > 0 {