From 85f2252a77a2629434460a70e45b281b5dbebd2a Mon Sep 17 00:00:00 2001
From: Philipp Heckel <pheckel@datto.com>
Date: Tue, 21 Jun 2022 19:07:27 -0400
Subject: [PATCH] WIP: Shorter lock, for #338

---
 server/message_cache.go      | 26 ++++++++++++++------------
 server/message_cache_test.go | 10 +++++-----
 server/server.go             | 32 ++++++++++++++++++++++----------
 3 files changed, 41 insertions(+), 27 deletions(-)

diff --git a/server/message_cache.go b/server/message_cache.go
index afd4bf17..d6024c80 100644
--- a/server/message_cache.go
+++ b/server/message_cache.go
@@ -82,7 +82,7 @@ const (
 	`
 	updateMessagePublishedQuery     = `UPDATE messages SET published = 1 WHERE mid = ?`
 	selectMessagesCountQuery        = `SELECT COUNT(*) FROM messages`
-	selectMessageCountForTopicQuery = `SELECT COUNT(*) FROM messages WHERE topic = ?`
+	selectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM messages GROUP BY topic`
 	selectTopicsQuery               = `SELECT topic FROM messages GROUP BY topic`
 	selectAttachmentsSizeQuery      = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE sender = ? AND attachment_expires >= ?`
 	selectAttachmentsExpiredQuery   = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires < ?`
@@ -332,22 +332,24 @@ func (c *messageCache) MarkPublished(m *message) error {
 	return err
 }
 
-func (c *messageCache) MessageCount(topic string) (int, error) {
-	rows, err := c.db.Query(selectMessageCountForTopicQuery, topic)
+func (c *messageCache) MessageCounts() (map[string]int, error) {
+	rows, err := c.db.Query(selectMessageCountPerTopicQuery)
 	if err != nil {
-		return 0, err
+		return nil, err
 	}
 	defer rows.Close()
+	var topic string
 	var count int
-	if !rows.Next() {
-		return 0, errors.New("no rows found")
+	counts := make(map[string]int)
+	for rows.Next() {
+		if err := rows.Scan(&topic, &count); err != nil {
+			return nil, err
+		} else if err := rows.Err(); err != nil {
+			return nil, err
+		}
+		counts[topic] = count
 	}
-	if err := rows.Scan(&count); err != nil {
-		return 0, err
-	} else if err := rows.Err(); err != nil {
-		return 0, err
-	}
-	return count, nil
+	return counts, nil
 }
 
 func (c *messageCache) Topics() (map[string]*topic, error) {
diff --git a/server/message_cache_test.go b/server/message_cache_test.go
index 398f21e4..9132088e 100644
--- a/server/message_cache_test.go
+++ b/server/message_cache_test.go
@@ -34,7 +34,7 @@ func testCacheMessages(t *testing.T, c *messageCache) {
 	require.Equal(t, errUnexpectedMessageType, c.AddMessage(newOpenMessage("example")))      // These should not be added!
 
 	// mytopic: count
-	count, err := c.MessageCount("mytopic")
+	count, err := c.MessageCounts("mytopic")
 	require.Nil(t, err)
 	require.Equal(t, 2, count)
 
@@ -66,7 +66,7 @@ func testCacheMessages(t *testing.T, c *messageCache) {
 	require.Equal(t, "my other message", messages[0].Message)
 
 	// example: count
-	count, err = c.MessageCount("example")
+	count, err = c.MessageCounts("example")
 	require.Nil(t, err)
 	require.Equal(t, 1, count)
 
@@ -75,7 +75,7 @@ func testCacheMessages(t *testing.T, c *messageCache) {
 	require.Equal(t, "my example message", messages[0].Message)
 
 	// non-existing: count
-	count, err = c.MessageCount("doesnotexist")
+	count, err = c.MessageCounts("doesnotexist")
 	require.Nil(t, err)
 	require.Equal(t, 0, count)
 
@@ -255,11 +255,11 @@ func testCachePrune(t *testing.T, c *messageCache) {
 	require.Nil(t, c.AddMessage(m3))
 	require.Nil(t, c.Prune(time.Unix(2, 0)))
 
-	count, err := c.MessageCount("mytopic")
+	count, err := c.MessageCounts("mytopic")
 	require.Nil(t, err)
 	require.Equal(t, 1, count)
 
-	count, err = c.MessageCount("another_topic")
+	count, err = c.MessageCounts("another_topic")
 	require.Nil(t, err)
 	require.Equal(t, 0, count)
 
diff --git a/server/server.go b/server/server.go
index 4d028d91..2c887cc1 100644
--- a/server/server.go
+++ b/server/server.go
@@ -1080,10 +1080,13 @@ func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) {
 }
 
 func (s *Server) updateStatsAndPrune() {
-	s.mu.Lock()
-	defer s.mu.Unlock()
+	log.Debug("Manager: Running cleanup")
+
+	// WARNING: Make sure to only selectively lock with the mutex, and be aware that this
+	//          there is no mutex for the entire function.
 
 	// Expire visitors from rate visitors map
+	s.mu.Lock()
 	staleVisitors := 0
 	for ip, v := range s.visitors {
 		if v.Stale() {
@@ -1092,6 +1095,7 @@ func (s *Server) updateStatsAndPrune() {
 			staleVisitors++
 		}
 	}
+	s.mu.Unlock()
 	log.Debug("Manager: Deleted %d stale visitor(s)", staleVisitors)
 
 	// Delete expired attachments
@@ -1116,22 +1120,30 @@ func (s *Server) updateStatsAndPrune() {
 		log.Warn("Manager: Error pruning cache: %s", err.Error())
 	}
 
+	// Message count per topic
+	var messages int
+	messageCounts, err := s.messageCache.MessageCounts()
+	if err != nil {
+		log.Warn("Manager: Cannot get message counts: %s", err.Error())
+		messageCounts = make(map[string]int) // Empty, so we can continue
+	}
+	for _, count := range messageCounts {
+		messages += count
+	}
+
 	// Prune old topics, remove subscriptions without subscribers
-	var subscribers, messages int
+	s.mu.Lock()
+	var subscribers int
 	for _, t := range s.topics {
 		subs := t.Subscribers()
-		msgs, err := s.messageCache.MessageCount(t.ID)
-		if err != nil {
-			log.Warn("Manager: Cannot get stats for topic %s: %s", t.ID, err.Error())
-			continue
-		}
-		if msgs == 0 && subs == 0 {
+		msgs, exists := messageCounts[t.ID]
+		if subs == 0 && (!exists || msgs == 0) {
 			delete(s.topics, t.ID)
 			continue
 		}
 		subscribers += subs
-		messages += msgs
 	}
+	s.mu.Unlock()
 
 	// Mail stats
 	var receivedMailTotal, receivedMailSuccess, receivedMailFailure int64