diff --git a/server/message_cache.go b/server/message_cache.go index f4433399..ec710e4f 100644 --- a/server/message_cache.go +++ b/server/message_cache.go @@ -188,8 +188,9 @@ const ( ) type messageCache struct { - db *sql.DB - nop bool + db *sql.DB + queue *util.BatchingQueue[*message] + nop bool } // newSqliteCache creates a SQLite file-backed cache @@ -201,10 +202,21 @@ func newSqliteCache(filename, startupQueries string, nop bool) (*messageCache, e if err := setupCacheDB(db, startupQueries); err != nil { return nil, err } - return &messageCache{ - db: db, - nop: nop, - }, nil + queue := util.NewBatchingQueue[*message](20, 500*time.Millisecond) + cache := &messageCache{ + db: db, + queue: queue, + nop: nop, + } + go func() { + for messages := range queue.Pop() { + log.Debug("Adding %d messages to cache", len(messages)) + if err := cache.addMessages(messages); err != nil { + log.Error("error: %s", err.Error()) + } + } + }() + return cache, nil } // newMemCache creates an in-memory cache @@ -232,6 +244,10 @@ func (c *messageCache) AddMessage(m *message) error { return c.addMessages([]*message{m}) } +func (c *messageCache) QueueMessage(m *message) { + c.queue.Push(m) +} + func (c *messageCache) addMessages(ms []*message) error { if c.nop { return nil diff --git a/server/server.go b/server/server.go index ef09100d..b90b7630 100644 --- a/server/server.go +++ b/server/server.go @@ -491,9 +491,11 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes log.Debug("%s Message delayed, will process later", logMessagePrefix(v, m)) } if cache { - if err := s.messageCache.AddMessage(m); err != nil { + log.Trace("%s Queuing for cache", logMessagePrefix(v, m)) + s.messageCache.QueueMessage(m) + /*if err := s.messageCache.AddMessage(m); err != nil { return nil, err - } + }*/ } s.mu.Lock() s.messages++ diff --git a/util/batching_queue.go b/util/batching_queue.go new file mode 100644 index 00000000..78116470 --- /dev/null +++ b/util/batching_queue.go @@ -0,0 +1,56 @@ +package util + +import ( + "sync" + "time" +) + +type BatchingQueue[T any] struct { + batchSize int + timeout time.Duration + in []T + out chan []T + mu sync.Mutex +} + +func NewBatchingQueue[T any](batchSize int, timeout time.Duration) *BatchingQueue[T] { + q := &BatchingQueue[T]{ + batchSize: batchSize, + timeout: timeout, + in: make([]T, 0), + out: make(chan []T), + } + ticker := time.NewTicker(timeout) + go func() { + for range ticker.C { + elements := q.popAll() + if len(elements) > 0 { + q.out <- elements + } + } + }() + return q +} + +func (c *BatchingQueue[T]) Push(element T) { + c.mu.Lock() + c.in = append(c.in, element) + limitReached := len(c.in) == c.batchSize + c.mu.Unlock() + if limitReached { + c.out <- c.popAll() + } +} + +func (c *BatchingQueue[T]) Pop() <-chan []T { + return c.out +} + +func (c *BatchingQueue[T]) popAll() []T { + c.mu.Lock() + defer c.mu.Unlock() + elements := make([]T, len(c.in)) + copy(elements, c.in) + c.in = c.in[:0] + return elements +} diff --git a/util/batching_queue_test.go b/util/batching_queue_test.go new file mode 100644 index 00000000..46bc06b8 --- /dev/null +++ b/util/batching_queue_test.go @@ -0,0 +1,25 @@ +package util_test + +import ( + "fmt" + "heckel.io/ntfy/util" + "math/rand" + "testing" + "time" +) + +func TestConcurrentQueue_Next(t *testing.T) { + q := util.NewBatchingQueue[int](25, 200*time.Millisecond) + go func() { + for batch := range q.Pop() { + fmt.Printf("Batch of %d items\n", len(batch)) + } + }() + for i := 0; i < 1000; i++ { + go func(i int) { + time.Sleep(time.Duration(rand.Intn(1000)) * time.Millisecond) + q.Push(i) + }(i) + } + time.Sleep(2 * time.Second) +}