mirror of
https://github.com/binwiederhier/ntfy.git
synced 2024-12-26 19:52:30 +01:00
Finish cache tests
This commit is contained in:
parent
b437a87266
commit
98c1ab9e86
8 changed files with 95 additions and 37 deletions
|
@ -17,5 +17,5 @@ type cache interface {
|
||||||
Messages(topic string, since sinceTime) ([]*message, error)
|
Messages(topic string, since sinceTime) ([]*message, error)
|
||||||
MessageCount(topic string) (int, error)
|
MessageCount(topic string) (int, error)
|
||||||
Topics() (map[string]*topic, error)
|
Topics() (map[string]*topic, error)
|
||||||
Prune(keep time.Duration) error
|
Prune(olderThan time.Time) error
|
||||||
}
|
}
|
||||||
|
|
|
@ -57,26 +57,30 @@ func (s *memCache) MessageCount(topic string) (int, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *memCache) Topics() (map[string]*topic, error) {
|
func (s *memCache) Topics() (map[string]*topic, error) {
|
||||||
// Hack since we know when this is called there are no messages!
|
s.mu.Lock()
|
||||||
return make(map[string]*topic), nil
|
defer s.mu.Unlock()
|
||||||
|
topics := make(map[string]*topic)
|
||||||
|
for topic := range s.messages {
|
||||||
|
topics[topic] = newTopic(topic)
|
||||||
|
}
|
||||||
|
return topics, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *memCache) Prune(keep time.Duration) error {
|
func (s *memCache) Prune(olderThan time.Time) error {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
for topic := range s.messages {
|
for topic := range s.messages {
|
||||||
s.pruneTopic(topic, keep)
|
s.pruneTopic(topic, olderThan)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *memCache) pruneTopic(topic string, keep time.Duration) {
|
func (s *memCache) pruneTopic(topic string, olderThan time.Time) {
|
||||||
for i, m := range s.messages[topic] {
|
messages := make([]*message, 0)
|
||||||
msgTime := time.Unix(m.Time, 0)
|
for _, m := range s.messages[topic] {
|
||||||
if time.Since(msgTime) < keep {
|
if m.Time >= olderThan.Unix() {
|
||||||
s.messages[topic] = s.messages[topic][i:]
|
messages = append(messages, m)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
s.messages[topic] = make([]*message, 0) // all messages expired
|
s.messages[topic] = messages
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,6 +7,15 @@ import (
|
||||||
func TestMemCache_Messages(t *testing.T) {
|
func TestMemCache_Messages(t *testing.T) {
|
||||||
testCacheMessages(t, newMemCache())
|
testCacheMessages(t, newMemCache())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMemCache_Topics(t *testing.T) {
|
||||||
|
testCacheTopics(t, newMemCache())
|
||||||
|
}
|
||||||
|
|
||||||
func TestMemCache_MessagesTagsPrioAndTitle(t *testing.T) {
|
func TestMemCache_MessagesTagsPrioAndTitle(t *testing.T) {
|
||||||
testCacheMessagesTagsPrioAndTitle(t, newMemCache())
|
testCacheMessagesTagsPrioAndTitle(t, newMemCache())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMemCache_Prune(t *testing.T) {
|
||||||
|
testCachePrune(t, newMemCache())
|
||||||
|
}
|
||||||
|
|
|
@ -36,7 +36,7 @@ const (
|
||||||
`
|
`
|
||||||
selectMessagesCountQuery = `SELECT COUNT(*) FROM messages`
|
selectMessagesCountQuery = `SELECT COUNT(*) FROM messages`
|
||||||
selectMessageCountForTopicQuery = `SELECT COUNT(*) FROM messages WHERE topic = ?`
|
selectMessageCountForTopicQuery = `SELECT COUNT(*) FROM messages WHERE topic = ?`
|
||||||
selectTopicsQuery = `SELECT topic, MAX(time) FROM messages GROUP BY topic`
|
selectTopicsQuery = `SELECT topic FROM messages GROUP BY topic`
|
||||||
)
|
)
|
||||||
|
|
||||||
// Schema management queries
|
// Schema management queries
|
||||||
|
@ -153,11 +153,10 @@ func (c *sqliteCache) Topics() (map[string]*topic, error) {
|
||||||
topics := make(map[string]*topic)
|
topics := make(map[string]*topic)
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var id string
|
var id string
|
||||||
var last int64
|
if err := rows.Scan(&id); err != nil {
|
||||||
if err := rows.Scan(&id, &last); err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
topics[id] = newTopic(id, time.Unix(last, 0))
|
topics[id] = newTopic(id)
|
||||||
}
|
}
|
||||||
if err := rows.Err(); err != nil {
|
if err := rows.Err(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -165,8 +164,8 @@ func (c *sqliteCache) Topics() (map[string]*topic, error) {
|
||||||
return topics, nil
|
return topics, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *sqliteCache) Prune(keep time.Duration) error {
|
func (c *sqliteCache) Prune(olderThan time.Time) error {
|
||||||
_, err := c.db.Exec(pruneMessagesQuery, time.Now().Add(-1*keep).Unix())
|
_, err := c.db.Exec(pruneMessagesQuery, olderThan.Unix())
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -9,10 +9,18 @@ func TestSqliteCache_AddMessage(t *testing.T) {
|
||||||
testCacheMessages(t, newSqliteTestCache(t))
|
testCacheMessages(t, newSqliteTestCache(t))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSqliteCache_Topics(t *testing.T) {
|
||||||
|
testCacheTopics(t, newSqliteTestCache(t))
|
||||||
|
}
|
||||||
|
|
||||||
func TestSqliteCache_MessagesTagsPrioAndTitle(t *testing.T) {
|
func TestSqliteCache_MessagesTagsPrioAndTitle(t *testing.T) {
|
||||||
testCacheMessagesTagsPrioAndTitle(t, newSqliteTestCache(t))
|
testCacheMessagesTagsPrioAndTitle(t, newSqliteTestCache(t))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSqliteCache_Prune(t *testing.T) {
|
||||||
|
testCachePrune(t, newSqliteTestCache(t))
|
||||||
|
}
|
||||||
|
|
||||||
func newSqliteTestCache(t *testing.T) cache {
|
func newSqliteTestCache(t *testing.T) cache {
|
||||||
filename := filepath.Join(t.TempDir(), "cache.db")
|
filename := filepath.Join(t.TempDir(), "cache.db")
|
||||||
c, err := newSqliteCache(filename)
|
c, err := newSqliteCache(filename)
|
||||||
|
|
|
@ -65,6 +65,50 @@ func testCacheMessages(t *testing.T, c cache) {
|
||||||
assert.Empty(t, messages)
|
assert.Empty(t, messages)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func testCacheTopics(t *testing.T, c cache) {
|
||||||
|
assert.Nil(t, c.AddMessage(newDefaultMessage("topic1", "my example message")))
|
||||||
|
assert.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 1")))
|
||||||
|
assert.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 2")))
|
||||||
|
assert.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 3")))
|
||||||
|
|
||||||
|
topics, err := c.Topics()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
assert.Equal(t, 2, len(topics))
|
||||||
|
assert.Equal(t, "topic1", topics["topic1"].ID)
|
||||||
|
assert.Equal(t, "topic2", topics["topic2"].ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testCachePrune(t *testing.T, c cache) {
|
||||||
|
m1 := newDefaultMessage("mytopic", "my message")
|
||||||
|
m1.Time = 1
|
||||||
|
|
||||||
|
m2 := newDefaultMessage("mytopic", "my other message")
|
||||||
|
m2.Time = 2
|
||||||
|
|
||||||
|
m3 := newDefaultMessage("another_topic", "and another one")
|
||||||
|
m3.Time = 1
|
||||||
|
|
||||||
|
assert.Nil(t, c.AddMessage(m1))
|
||||||
|
assert.Nil(t, c.AddMessage(m2))
|
||||||
|
assert.Nil(t, c.AddMessage(m3))
|
||||||
|
assert.Nil(t, c.Prune(time.Unix(2, 0)))
|
||||||
|
|
||||||
|
count, err := c.MessageCount("mytopic")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, 1, count)
|
||||||
|
|
||||||
|
count, err = c.MessageCount("another_topic")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, 0, count)
|
||||||
|
|
||||||
|
messages, err := c.Messages("mytopic", sinceAllMessages)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, 1, len(messages))
|
||||||
|
assert.Equal(t, "my other message", messages[0].Message)
|
||||||
|
}
|
||||||
|
|
||||||
func testCacheMessagesTagsPrioAndTitle(t *testing.T, c cache) {
|
func testCacheMessagesTagsPrioAndTitle(t *testing.T, c cache) {
|
||||||
m := newDefaultMessage("mytopic", "some message")
|
m := newDefaultMessage("mytopic", "some message")
|
||||||
m.Tags = []string{"tag1", "tag2"}
|
m.Tags = []string{"tag1", "tag2"}
|
||||||
|
|
|
@ -274,7 +274,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, _ *visito
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
m := newDefaultMessage(t.id, string(b))
|
m := newDefaultMessage(t.ID, string(b))
|
||||||
if m.Message == "" {
|
if m.Message == "" {
|
||||||
return errHTTPBadRequest
|
return errHTTPBadRequest
|
||||||
}
|
}
|
||||||
|
@ -442,7 +442,7 @@ func (s *Server) sendOldMessages(topics []*topic, since sinceTime, sub subscribe
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
for _, t := range topics {
|
for _, t := range topics {
|
||||||
messages, err := s.cache.Messages(t.id, since)
|
messages, err := s.cache.Messages(t.ID, since)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -468,11 +468,9 @@ func parseSince(r *http.Request) (sinceTime, error) {
|
||||||
}
|
}
|
||||||
if r.URL.Query().Get("since") == "all" {
|
if r.URL.Query().Get("since") == "all" {
|
||||||
return sinceAllMessages, nil
|
return sinceAllMessages, nil
|
||||||
}
|
} else if s, err := strconv.ParseInt(r.URL.Query().Get("since"), 10, 64); err == nil {
|
||||||
if s, err := strconv.ParseInt(r.URL.Query().Get("since"), 10, 64); err == nil {
|
|
||||||
return sinceTime(time.Unix(s, 0)), nil
|
return sinceTime(time.Unix(s, 0)), nil
|
||||||
}
|
} else if d, err := time.ParseDuration(r.URL.Query().Get("since")); err == nil {
|
||||||
if d, err := time.ParseDuration(r.URL.Query().Get("since")); err == nil {
|
|
||||||
return sinceTime(time.Now().Add(-1 * d)), nil
|
return sinceTime(time.Now().Add(-1 * d)), nil
|
||||||
}
|
}
|
||||||
return sinceNoMessages, errHTTPBadRequest
|
return sinceNoMessages, errHTTPBadRequest
|
||||||
|
@ -504,7 +502,7 @@ func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) {
|
||||||
if len(s.topics) >= s.config.GlobalTopicLimit {
|
if len(s.topics) >= s.config.GlobalTopicLimit {
|
||||||
return nil, errHTTPTooManyRequests
|
return nil, errHTTPTooManyRequests
|
||||||
}
|
}
|
||||||
s.topics[id] = newTopic(id, time.Now())
|
s.topics[id] = newTopic(id)
|
||||||
if s.firebase != nil {
|
if s.firebase != nil {
|
||||||
s.topics[id].Subscribe(s.firebase)
|
s.topics[id].Subscribe(s.firebase)
|
||||||
}
|
}
|
||||||
|
@ -526,7 +524,8 @@ func (s *Server) updateStatsAndExpire() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prune cache
|
// Prune cache
|
||||||
if err := s.cache.Prune(s.config.CacheDuration); err != nil {
|
olderThan := time.Now().Add(-1 * s.config.CacheDuration)
|
||||||
|
if err := s.cache.Prune(olderThan); err != nil {
|
||||||
log.Printf("error pruning cache: %s", err.Error())
|
log.Printf("error pruning cache: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -534,13 +533,13 @@ func (s *Server) updateStatsAndExpire() {
|
||||||
var subscribers, messages int
|
var subscribers, messages int
|
||||||
for _, t := range s.topics {
|
for _, t := range s.topics {
|
||||||
subs := t.Subscribers()
|
subs := t.Subscribers()
|
||||||
msgs, err := s.cache.MessageCount(t.id)
|
msgs, err := s.cache.MessageCount(t.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("cannot get stats for topic %s: %s", t.id, err.Error())
|
log.Printf("cannot get stats for topic %s: %s", t.ID, err.Error())
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if msgs == 0 && (subs == 0 || (s.firebase != nil && subs == 1)) { // Firebase is a subscriber!
|
if msgs == 0 && (subs == 0 || (s.firebase != nil && subs == 1)) { // Firebase is a subscriber!
|
||||||
delete(s.topics, t.id)
|
delete(s.topics, t.ID)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
subscribers += subs
|
subscribers += subs
|
||||||
|
|
|
@ -4,14 +4,12 @@ import (
|
||||||
"log"
|
"log"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// topic represents a channel to which subscribers can subscribe, and publishers
|
// topic represents a channel to which subscribers can subscribe, and publishers
|
||||||
// can publish a message
|
// can publish a message
|
||||||
type topic struct {
|
type topic struct {
|
||||||
id string
|
ID string
|
||||||
last time.Time
|
|
||||||
subscribers map[int]subscriber
|
subscribers map[int]subscriber
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
@ -20,10 +18,9 @@ type topic struct {
|
||||||
type subscriber func(msg *message) error
|
type subscriber func(msg *message) error
|
||||||
|
|
||||||
// newTopic creates a new topic
|
// newTopic creates a new topic
|
||||||
func newTopic(id string, last time.Time) *topic {
|
func newTopic(id string) *topic {
|
||||||
return &topic{
|
return &topic{
|
||||||
id: id,
|
ID: id,
|
||||||
last: last,
|
|
||||||
subscribers: make(map[int]subscriber),
|
subscribers: make(map[int]subscriber),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -34,7 +31,6 @@ func (t *topic) Subscribe(s subscriber) int {
|
||||||
defer t.mu.Unlock()
|
defer t.mu.Unlock()
|
||||||
subscriberID := rand.Int()
|
subscriberID := rand.Int()
|
||||||
t.subscribers[subscriberID] = s
|
t.subscribers[subscriberID] = s
|
||||||
t.last = time.Now()
|
|
||||||
return subscriberID
|
return subscriberID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -50,7 +46,6 @@ func (t *topic) Publish(m *message) error {
|
||||||
go func() {
|
go func() {
|
||||||
t.mu.Lock()
|
t.mu.Lock()
|
||||||
defer t.mu.Unlock()
|
defer t.mu.Unlock()
|
||||||
t.last = time.Now()
|
|
||||||
for _, s := range t.subscribers {
|
for _, s := range t.subscribers {
|
||||||
if err := s(m); err != nil {
|
if err := s(m); err != nil {
|
||||||
log.Printf("error publishing message to subscriber")
|
log.Printf("error publishing message to subscriber")
|
||||||
|
|
Loading…
Reference in a new issue