From e7c19a2bad5c59b869ae3f72fb263accd49cbdb7 Mon Sep 17 00:00:00 2001 From: Philipp Heckel Date: Fri, 7 Jan 2022 15:15:33 +0100 Subject: [PATCH] Expire attachments properly --- server/cache.go | 1 + server/cache_mem.go | 14 +++++++++ server/cache_sqlite.go | 25 ++++++++++++++-- server/file_cache.go | 65 ++++++++++++++++++++++++++++++++++-------- server/server.go | 10 +++++++ 5 files changed, 101 insertions(+), 14 deletions(-) diff --git a/server/cache.go b/server/cache.go index 7532ff7f..89db72c4 100644 --- a/server/cache.go +++ b/server/cache.go @@ -21,4 +21,5 @@ type cache interface { Prune(olderThan time.Time) error MarkPublished(m *message) error AttachmentsSize(owner string) (int64, error) + AttachmentsExpired() ([]string, error) } diff --git a/server/cache_mem.go b/server/cache_mem.go index 91bcb38c..04e57be9 100644 --- a/server/cache_mem.go +++ b/server/cache_mem.go @@ -139,6 +139,20 @@ func (c *memCache) AttachmentsSize(owner string) (int64, error) { return size, nil } +func (c *memCache) AttachmentsExpired() ([]string, error) { + c.mu.Lock() + defer c.mu.Unlock() + ids := make([]string, 0) + for topic := range c.messages { + for _, m := range c.messages[topic] { + if m.Attachment != nil && m.Attachment.Expires > 0 && m.Attachment.Expires < time.Now().Unix() { + ids = append(ids, m.ID) + } + } + } + return ids, nil +} + func (c *memCache) pruneTopic(topic string, olderThan time.Time) { messages := make([]*message, 0) for _, m := range c.messages[topic] { diff --git a/server/cache_sqlite.go b/server/cache_sqlite.go index 4a52f281..c8d97735 100644 --- a/server/cache_sqlite.go +++ b/server/cache_sqlite.go @@ -60,7 +60,8 @@ const ( selectMessagesCountQuery = `SELECT COUNT(*) FROM messages` selectMessageCountForTopicQuery = `SELECT COUNT(*) FROM messages WHERE topic = ?` selectTopicsQuery = `SELECT topic FROM messages GROUP BY topic` - selectAttachmentsSizeQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE attachment_owner = ?` + selectAttachmentsSizeQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE attachment_owner = ? AND attachment_expires >= ?` + selectAttachmentsExpiredQuery = `SELECT id FROM messages WHERE attachment_expires > 0 AND attachment_expires < ?` ) // Schema management queries @@ -234,7 +235,7 @@ func (c *sqliteCache) Prune(olderThan time.Time) error { } func (c *sqliteCache) AttachmentsSize(owner string) (int64, error) { - rows, err := c.db.Query(selectAttachmentsSizeQuery, owner) + rows, err := c.db.Query(selectAttachmentsSizeQuery, owner, time.Now().Unix()) if err != nil { return 0, err } @@ -251,6 +252,26 @@ func (c *sqliteCache) AttachmentsSize(owner string) (int64, error) { return size, nil } +func (c *sqliteCache) AttachmentsExpired() ([]string, error) { + rows, err := c.db.Query(selectAttachmentsExpiredQuery, time.Now().Unix()) + if err != nil { + return nil, err + } + defer rows.Close() + ids := make([]string, 0) + for rows.Next() { + var id string + if err := rows.Scan(&id); err != nil { + return nil, err + } + ids = append(ids, id) + } + if err := rows.Err(); err != nil { + return nil, err + } + return ids, nil +} + func readMessages(rows *sql.Rows) ([]*message, error) { defer rows.Close() messages := make([]*message, 0) diff --git a/server/file_cache.go b/server/file_cache.go index e23718c6..244142cf 100644 --- a/server/file_cache.go +++ b/server/file_cache.go @@ -28,18 +28,10 @@ func newFileCache(dir string, totalSizeLimit int64, fileSizeLimit int64) (*fileC if err := os.MkdirAll(dir, 0700); err != nil { return nil, err } - entries, err := os.ReadDir(dir) + size, err := dirSize(dir) if err != nil { return nil, err } - var size int64 - for _, e := range entries { - info, err := e.Info() - if err != nil { - return nil, err - } - size += info.Size() - } return &fileCache{ dir: dir, totalSizeCurrent: size, @@ -58,8 +50,8 @@ func (c *fileCache) Write(id string, in io.Reader, limiters ...*util.Limiter) (i return 0, err } defer f.Close() - log.Printf("remaining total: %d", c.remainingTotalSize()) - limiters = append(limiters, util.NewLimiter(c.remainingTotalSize()), util.NewLimiter(c.fileSizeLimit)) + log.Printf("remaining total: %d", c.Remaining()) + limiters = append(limiters, util.NewLimiter(c.Remaining()), util.NewLimiter(c.fileSizeLimit)) limitWriter := util.NewLimitWriter(f, limiters...) size, err := io.Copy(limitWriter, in) if err != nil { @@ -77,7 +69,40 @@ func (c *fileCache) Write(id string, in io.Reader, limiters ...*util.Limiter) (i } -func (c *fileCache) remainingTotalSize() int64 { +func (c *fileCache) Remove(ids []string) error { + var firstErr error + for _, id := range ids { + if err := c.removeFile(id); err != nil { + if firstErr == nil { + firstErr = err // Continue despite error; we want to delete as many as we can + } + } + } + size, err := dirSize(c.dir) + if err != nil { + return err + } + c.mu.Lock() + c.totalSizeCurrent = size + c.mu.Unlock() + return firstErr +} + +func (c *fileCache) removeFile(id string) error { + if !fileIDRegex.MatchString(id) { + return errInvalidFileID + } + file := filepath.Join(c.dir, id) + return os.Remove(file) +} + +func (c *fileCache) Size() int64 { + c.mu.Lock() + defer c.mu.Unlock() + return c.totalSizeCurrent +} + +func (c *fileCache) Remaining() int64 { c.mu.Lock() defer c.mu.Unlock() remaining := c.totalSizeLimit - c.totalSizeCurrent @@ -86,3 +111,19 @@ func (c *fileCache) remainingTotalSize() int64 { } return remaining } + +func dirSize(dir string) (int64, error) { + entries, err := os.ReadDir(dir) + if err != nil { + return 0, err + } + var size int64 + for _, e := range entries { + info, err := e.Info() + if err != nil { + return 0, err + } + size += info.Size() + } + return size, nil +} diff --git a/server/server.go b/server/server.go index afd7d38d..c3ee81fa 100644 --- a/server/server.go +++ b/server/server.go @@ -832,6 +832,16 @@ func (s *Server) updateStatsAndPrune() { } } + // Delete expired attachments + ids, err := s.cache.AttachmentsExpired() + if err == nil { + if err := s.fileCache.Remove(ids); err != nil { + log.Printf("error while deleting attachments: %s", err.Error()) + } + } else { + log.Printf("error retrieving expired attachments: %s", err.Error()) + } + // Prune message cache olderThan := time.Now().Add(-1 * s.config.CacheDuration) if err := s.cache.Prune(olderThan); err != nil {