1
0
Fork 0
mirror of https://github.com/binwiederhier/ntfy.git synced 2024-12-27 12:12:28 +01:00

Finish cache tests

This commit is contained in:
Philipp Heckel 2021-12-08 22:57:31 -05:00
parent b437a87266
commit 98c1ab9e86
8 changed files with 95 additions and 37 deletions

View file

@ -17,5 +17,5 @@ type cache interface {
Messages(topic string, since sinceTime) ([]*message, error) Messages(topic string, since sinceTime) ([]*message, error)
MessageCount(topic string) (int, error) MessageCount(topic string) (int, error)
Topics() (map[string]*topic, error) Topics() (map[string]*topic, error)
Prune(keep time.Duration) error Prune(olderThan time.Time) error
} }

View file

@ -57,26 +57,30 @@ func (s *memCache) MessageCount(topic string) (int, error) {
} }
func (s *memCache) Topics() (map[string]*topic, error) { func (s *memCache) Topics() (map[string]*topic, error) {
// Hack since we know when this is called there are no messages! s.mu.Lock()
return make(map[string]*topic), nil defer s.mu.Unlock()
topics := make(map[string]*topic)
for topic := range s.messages {
topics[topic] = newTopic(topic)
}
return topics, nil
} }
func (s *memCache) Prune(keep time.Duration) error { func (s *memCache) Prune(olderThan time.Time) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
for topic := range s.messages { for topic := range s.messages {
s.pruneTopic(topic, keep) s.pruneTopic(topic, olderThan)
} }
return nil return nil
} }
func (s *memCache) pruneTopic(topic string, keep time.Duration) { func (s *memCache) pruneTopic(topic string, olderThan time.Time) {
for i, m := range s.messages[topic] { messages := make([]*message, 0)
msgTime := time.Unix(m.Time, 0) for _, m := range s.messages[topic] {
if time.Since(msgTime) < keep { if m.Time >= olderThan.Unix() {
s.messages[topic] = s.messages[topic][i:] messages = append(messages, m)
return
} }
} }
s.messages[topic] = make([]*message, 0) // all messages expired s.messages[topic] = messages
} }

View file

@ -7,6 +7,15 @@ import (
func TestMemCache_Messages(t *testing.T) { func TestMemCache_Messages(t *testing.T) {
testCacheMessages(t, newMemCache()) testCacheMessages(t, newMemCache())
} }
func TestMemCache_Topics(t *testing.T) {
testCacheTopics(t, newMemCache())
}
func TestMemCache_MessagesTagsPrioAndTitle(t *testing.T) { func TestMemCache_MessagesTagsPrioAndTitle(t *testing.T) {
testCacheMessagesTagsPrioAndTitle(t, newMemCache()) testCacheMessagesTagsPrioAndTitle(t, newMemCache())
} }
func TestMemCache_Prune(t *testing.T) {
testCachePrune(t, newMemCache())
}

View file

@ -36,7 +36,7 @@ const (
` `
selectMessagesCountQuery = `SELECT COUNT(*) FROM messages` selectMessagesCountQuery = `SELECT COUNT(*) FROM messages`
selectMessageCountForTopicQuery = `SELECT COUNT(*) FROM messages WHERE topic = ?` selectMessageCountForTopicQuery = `SELECT COUNT(*) FROM messages WHERE topic = ?`
selectTopicsQuery = `SELECT topic, MAX(time) FROM messages GROUP BY topic` selectTopicsQuery = `SELECT topic FROM messages GROUP BY topic`
) )
// Schema management queries // Schema management queries
@ -153,11 +153,10 @@ func (c *sqliteCache) Topics() (map[string]*topic, error) {
topics := make(map[string]*topic) topics := make(map[string]*topic)
for rows.Next() { for rows.Next() {
var id string var id string
var last int64 if err := rows.Scan(&id); err != nil {
if err := rows.Scan(&id, &last); err != nil {
return nil, err return nil, err
} }
topics[id] = newTopic(id, time.Unix(last, 0)) topics[id] = newTopic(id)
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
return nil, err return nil, err
@ -165,8 +164,8 @@ func (c *sqliteCache) Topics() (map[string]*topic, error) {
return topics, nil return topics, nil
} }
func (c *sqliteCache) Prune(keep time.Duration) error { func (c *sqliteCache) Prune(olderThan time.Time) error {
_, err := c.db.Exec(pruneMessagesQuery, time.Now().Add(-1*keep).Unix()) _, err := c.db.Exec(pruneMessagesQuery, olderThan.Unix())
return err return err
} }

View file

@ -9,10 +9,18 @@ func TestSqliteCache_AddMessage(t *testing.T) {
testCacheMessages(t, newSqliteTestCache(t)) testCacheMessages(t, newSqliteTestCache(t))
} }
func TestSqliteCache_Topics(t *testing.T) {
testCacheTopics(t, newSqliteTestCache(t))
}
func TestSqliteCache_MessagesTagsPrioAndTitle(t *testing.T) { func TestSqliteCache_MessagesTagsPrioAndTitle(t *testing.T) {
testCacheMessagesTagsPrioAndTitle(t, newSqliteTestCache(t)) testCacheMessagesTagsPrioAndTitle(t, newSqliteTestCache(t))
} }
func TestSqliteCache_Prune(t *testing.T) {
testCachePrune(t, newSqliteTestCache(t))
}
func newSqliteTestCache(t *testing.T) cache { func newSqliteTestCache(t *testing.T) cache {
filename := filepath.Join(t.TempDir(), "cache.db") filename := filepath.Join(t.TempDir(), "cache.db")
c, err := newSqliteCache(filename) c, err := newSqliteCache(filename)

View file

@ -65,6 +65,50 @@ func testCacheMessages(t *testing.T, c cache) {
assert.Empty(t, messages) assert.Empty(t, messages)
} }
func testCacheTopics(t *testing.T, c cache) {
assert.Nil(t, c.AddMessage(newDefaultMessage("topic1", "my example message")))
assert.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 1")))
assert.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 2")))
assert.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 3")))
topics, err := c.Topics()
if err != nil {
t.Fatal(err)
}
assert.Equal(t, 2, len(topics))
assert.Equal(t, "topic1", topics["topic1"].ID)
assert.Equal(t, "topic2", topics["topic2"].ID)
}
func testCachePrune(t *testing.T, c cache) {
m1 := newDefaultMessage("mytopic", "my message")
m1.Time = 1
m2 := newDefaultMessage("mytopic", "my other message")
m2.Time = 2
m3 := newDefaultMessage("another_topic", "and another one")
m3.Time = 1
assert.Nil(t, c.AddMessage(m1))
assert.Nil(t, c.AddMessage(m2))
assert.Nil(t, c.AddMessage(m3))
assert.Nil(t, c.Prune(time.Unix(2, 0)))
count, err := c.MessageCount("mytopic")
assert.Nil(t, err)
assert.Equal(t, 1, count)
count, err = c.MessageCount("another_topic")
assert.Nil(t, err)
assert.Equal(t, 0, count)
messages, err := c.Messages("mytopic", sinceAllMessages)
assert.Nil(t, err)
assert.Equal(t, 1, len(messages))
assert.Equal(t, "my other message", messages[0].Message)
}
func testCacheMessagesTagsPrioAndTitle(t *testing.T, c cache) { func testCacheMessagesTagsPrioAndTitle(t *testing.T, c cache) {
m := newDefaultMessage("mytopic", "some message") m := newDefaultMessage("mytopic", "some message")
m.Tags = []string{"tag1", "tag2"} m.Tags = []string{"tag1", "tag2"}

View file

@ -274,7 +274,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, _ *visito
if err != nil { if err != nil {
return err return err
} }
m := newDefaultMessage(t.id, string(b)) m := newDefaultMessage(t.ID, string(b))
if m.Message == "" { if m.Message == "" {
return errHTTPBadRequest return errHTTPBadRequest
} }
@ -442,7 +442,7 @@ func (s *Server) sendOldMessages(topics []*topic, since sinceTime, sub subscribe
return nil return nil
} }
for _, t := range topics { for _, t := range topics {
messages, err := s.cache.Messages(t.id, since) messages, err := s.cache.Messages(t.ID, since)
if err != nil { if err != nil {
return err return err
} }
@ -468,11 +468,9 @@ func parseSince(r *http.Request) (sinceTime, error) {
} }
if r.URL.Query().Get("since") == "all" { if r.URL.Query().Get("since") == "all" {
return sinceAllMessages, nil return sinceAllMessages, nil
} } else if s, err := strconv.ParseInt(r.URL.Query().Get("since"), 10, 64); err == nil {
if s, err := strconv.ParseInt(r.URL.Query().Get("since"), 10, 64); err == nil {
return sinceTime(time.Unix(s, 0)), nil return sinceTime(time.Unix(s, 0)), nil
} } else if d, err := time.ParseDuration(r.URL.Query().Get("since")); err == nil {
if d, err := time.ParseDuration(r.URL.Query().Get("since")); err == nil {
return sinceTime(time.Now().Add(-1 * d)), nil return sinceTime(time.Now().Add(-1 * d)), nil
} }
return sinceNoMessages, errHTTPBadRequest return sinceNoMessages, errHTTPBadRequest
@ -504,7 +502,7 @@ func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) {
if len(s.topics) >= s.config.GlobalTopicLimit { if len(s.topics) >= s.config.GlobalTopicLimit {
return nil, errHTTPTooManyRequests return nil, errHTTPTooManyRequests
} }
s.topics[id] = newTopic(id, time.Now()) s.topics[id] = newTopic(id)
if s.firebase != nil { if s.firebase != nil {
s.topics[id].Subscribe(s.firebase) s.topics[id].Subscribe(s.firebase)
} }
@ -526,7 +524,8 @@ func (s *Server) updateStatsAndExpire() {
} }
// Prune cache // Prune cache
if err := s.cache.Prune(s.config.CacheDuration); err != nil { olderThan := time.Now().Add(-1 * s.config.CacheDuration)
if err := s.cache.Prune(olderThan); err != nil {
log.Printf("error pruning cache: %s", err.Error()) log.Printf("error pruning cache: %s", err.Error())
} }
@ -534,13 +533,13 @@ func (s *Server) updateStatsAndExpire() {
var subscribers, messages int var subscribers, messages int
for _, t := range s.topics { for _, t := range s.topics {
subs := t.Subscribers() subs := t.Subscribers()
msgs, err := s.cache.MessageCount(t.id) msgs, err := s.cache.MessageCount(t.ID)
if err != nil { if err != nil {
log.Printf("cannot get stats for topic %s: %s", t.id, err.Error()) log.Printf("cannot get stats for topic %s: %s", t.ID, err.Error())
continue continue
} }
if msgs == 0 && (subs == 0 || (s.firebase != nil && subs == 1)) { // Firebase is a subscriber! if msgs == 0 && (subs == 0 || (s.firebase != nil && subs == 1)) { // Firebase is a subscriber!
delete(s.topics, t.id) delete(s.topics, t.ID)
continue continue
} }
subscribers += subs subscribers += subs

View file

@ -4,14 +4,12 @@ import (
"log" "log"
"math/rand" "math/rand"
"sync" "sync"
"time"
) )
// topic represents a channel to which subscribers can subscribe, and publishers // topic represents a channel to which subscribers can subscribe, and publishers
// can publish a message // can publish a message
type topic struct { type topic struct {
id string ID string
last time.Time
subscribers map[int]subscriber subscribers map[int]subscriber
mu sync.Mutex mu sync.Mutex
} }
@ -20,10 +18,9 @@ type topic struct {
type subscriber func(msg *message) error type subscriber func(msg *message) error
// newTopic creates a new topic // newTopic creates a new topic
func newTopic(id string, last time.Time) *topic { func newTopic(id string) *topic {
return &topic{ return &topic{
id: id, ID: id,
last: last,
subscribers: make(map[int]subscriber), subscribers: make(map[int]subscriber),
} }
} }
@ -34,7 +31,6 @@ func (t *topic) Subscribe(s subscriber) int {
defer t.mu.Unlock() defer t.mu.Unlock()
subscriberID := rand.Int() subscriberID := rand.Int()
t.subscribers[subscriberID] = s t.subscribers[subscriberID] = s
t.last = time.Now()
return subscriberID return subscriberID
} }
@ -50,7 +46,6 @@ func (t *topic) Publish(m *message) error {
go func() { go func() {
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock() defer t.mu.Unlock()
t.last = time.Now()
for _, s := range t.subscribers { for _, s := range t.subscribers {
if err := s(m); err != nil { if err := s(m); err != nil {
log.Printf("error publishing message to subscriber") log.Printf("error publishing message to subscriber")