From b4933a5645f296c519f12a8f78bdeb7cfd3f17b9 Mon Sep 17 00:00:00 2001
From: Philipp Heckel <pheckel@datto.com>
Date: Tue, 15 Nov 2022 14:24:56 -0500
Subject: [PATCH] WIP: Batch message INSERTs

---
 server/message_cache.go     | 28 +++++++++++++++----
 server/server.go            |  6 ++--
 util/batching_queue.go      | 56 +++++++++++++++++++++++++++++++++++++
 util/batching_queue_test.go | 25 +++++++++++++++++
 4 files changed, 107 insertions(+), 8 deletions(-)
 create mode 100644 util/batching_queue.go
 create mode 100644 util/batching_queue_test.go

diff --git a/server/message_cache.go b/server/message_cache.go
index f4433399..ec710e4f 100644
--- a/server/message_cache.go
+++ b/server/message_cache.go
@@ -188,8 +188,9 @@ const (
 )
 
 type messageCache struct {
-	db  *sql.DB
-	nop bool
+	db    *sql.DB
+	queue *util.BatchingQueue[*message]
+	nop   bool
 }
 
 // newSqliteCache creates a SQLite file-backed cache
@@ -201,10 +202,21 @@ func newSqliteCache(filename, startupQueries string, nop bool) (*messageCache, e
 	if err := setupCacheDB(db, startupQueries); err != nil {
 		return nil, err
 	}
-	return &messageCache{
-		db:  db,
-		nop: nop,
-	}, nil
+	queue := util.NewBatchingQueue[*message](20, 500*time.Millisecond)
+	cache := &messageCache{
+		db:    db,
+		queue: queue,
+		nop:   nop,
+	}
+	go func() {
+		for messages := range queue.Pop() {
+			log.Debug("Adding %d messages to cache", len(messages))
+			if err := cache.addMessages(messages); err != nil {
+				log.Error("error: %s", err.Error())
+			}
+		}
+	}()
+	return cache, nil
 }
 
 // newMemCache creates an in-memory cache
@@ -232,6 +244,10 @@ func (c *messageCache) AddMessage(m *message) error {
 	return c.addMessages([]*message{m})
 }
 
+func (c *messageCache) QueueMessage(m *message) {
+	c.queue.Push(m)
+}
+
 func (c *messageCache) addMessages(ms []*message) error {
 	if c.nop {
 		return nil
diff --git a/server/server.go b/server/server.go
index ef09100d..b90b7630 100644
--- a/server/server.go
+++ b/server/server.go
@@ -491,9 +491,11 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
 		log.Debug("%s Message delayed, will process later", logMessagePrefix(v, m))
 	}
 	if cache {
-		if err := s.messageCache.AddMessage(m); err != nil {
+		log.Trace("%s Queuing for cache", logMessagePrefix(v, m))
+		s.messageCache.QueueMessage(m)
+		/*if err := s.messageCache.AddMessage(m); err != nil {
 			return nil, err
-		}
+		}*/
 	}
 	s.mu.Lock()
 	s.messages++
diff --git a/util/batching_queue.go b/util/batching_queue.go
new file mode 100644
index 00000000..78116470
--- /dev/null
+++ b/util/batching_queue.go
@@ -0,0 +1,56 @@
+package util
+
+import (
+	"sync"
+	"time"
+)
+
+type BatchingQueue[T any] struct {
+	batchSize int
+	timeout   time.Duration
+	in        []T
+	out       chan []T
+	mu        sync.Mutex
+}
+
+func NewBatchingQueue[T any](batchSize int, timeout time.Duration) *BatchingQueue[T] {
+	q := &BatchingQueue[T]{
+		batchSize: batchSize,
+		timeout:   timeout,
+		in:        make([]T, 0),
+		out:       make(chan []T),
+	}
+	ticker := time.NewTicker(timeout)
+	go func() {
+		for range ticker.C {
+			elements := q.popAll()
+			if len(elements) > 0 {
+				q.out <- elements
+			}
+		}
+	}()
+	return q
+}
+
+func (c *BatchingQueue[T]) Push(element T) {
+	c.mu.Lock()
+	c.in = append(c.in, element)
+	limitReached := len(c.in) == c.batchSize
+	c.mu.Unlock()
+	if limitReached {
+		c.out <- c.popAll()
+	}
+}
+
+func (c *BatchingQueue[T]) Pop() <-chan []T {
+	return c.out
+}
+
+func (c *BatchingQueue[T]) popAll() []T {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	elements := make([]T, len(c.in))
+	copy(elements, c.in)
+	c.in = c.in[:0]
+	return elements
+}
diff --git a/util/batching_queue_test.go b/util/batching_queue_test.go
new file mode 100644
index 00000000..46bc06b8
--- /dev/null
+++ b/util/batching_queue_test.go
@@ -0,0 +1,25 @@
+package util_test
+
+import (
+	"fmt"
+	"heckel.io/ntfy/util"
+	"math/rand"
+	"testing"
+	"time"
+)
+
+func TestConcurrentQueue_Next(t *testing.T) {
+	q := util.NewBatchingQueue[int](25, 200*time.Millisecond)
+	go func() {
+		for batch := range q.Pop() {
+			fmt.Printf("Batch of %d items\n", len(batch))
+		}
+	}()
+	for i := 0; i < 1000; i++ {
+		go func(i int) {
+			time.Sleep(time.Duration(rand.Intn(1000)) * time.Millisecond)
+			q.Push(i)
+		}(i)
+	}
+	time.Sleep(2 * time.Second)
+}