diff --git a/server/message_cache.go b/server/message_cache.go index 8e787db3..c685d6a4 100644 --- a/server/message_cache.go +++ b/server/message_cache.go @@ -65,13 +65,13 @@ const ( selectMessagesSinceIDQuery = ` SELECT mid, time, updated, topic, message, title, priority, tags, click, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_owner, encoding FROM messages - WHERE topic = ? AND id > ? AND published = 1 + WHERE topic = ? AND id >= ? AND published = 1 ORDER BY time, id ` selectMessagesSinceIDIncludeScheduledQuery = ` SELECT mid, time, updated, topic, message, title, priority, tags, click, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_owner, encoding FROM messages - WHERE topic = ? AND (id > ? OR published = 0) + WHERE topic = ? AND (id >= ? OR published = 0) ORDER BY time, id ` selectMessagesDueQuery = ` @@ -95,7 +95,7 @@ const ( // Schema management queries const ( - currentSchemaVersion = 5 + currentSchemaVersion = 6 createSchemaVersionTableQuery = ` CREATE TABLE IF NOT EXISTS schemaVersion ( id INT PRIMARY KEY, @@ -173,6 +173,11 @@ const ( ALTER TABLE messages_new RENAME TO messages; COMMIT; ` + + // 5 -> 6 + migrate5To6AlterMessagesTableQuery = ` + ALTER TABLE messages ADD COLUMN updated INT NOT NULL DEFAULT (0); + ` ) type messageCache struct { @@ -326,7 +331,15 @@ func (c *messageCache) messagesSinceID(topic string, since sinceMarker, schedule if err != nil { return nil, err } - return readMessages(rows) + messages, err := readMessages(rows) + if err != nil { + return nil, err + } else if len(messages) == 0 { + return messages, nil + } else if since.IsTime() && messages[0].Updated > since.Time().Unix() { + return messages, nil + } + return messages[1:], nil // Do not include row with ID itself } func (c *messageCache) MessagesDue() ([]*message, error) { @@ -536,6 +549,8 @@ func setupCacheDB(db *sql.DB) error { return migrateFrom3(db) } else if schemaVersion == 4 { return migrateFrom4(db) + } else if schemaVersion == 5 { + return migrateFrom5(db) } return fmt.Errorf("unexpected schema version found: %d", schemaVersion) } @@ -608,5 +623,16 @@ func migrateFrom4(db *sql.DB) error { if _, err := db.Exec(updateSchemaVersion, 5); err != nil { return err } + return migrateFrom5(db) +} + +func migrateFrom5(db *sql.DB) error { + log.Print("Migrating cache database schema: from 5 to 6") + if _, err := db.Exec(migrate5To6AlterMessagesTableQuery); err != nil { + return err + } + if _, err := db.Exec(updateSchemaVersion, 6); err != nil { + return err + } return nil // Update this when a new version is added } diff --git a/server/server.go b/server/server.go index 0ef6d89d..e07a647e 100644 --- a/server/server.go +++ b/server/server.go @@ -473,8 +473,14 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca } // TODO more restrictions firebase = readBoolParam(r, true, "x-firebase", "firebase") - m.Title = readParam(r, "x-title", "title", "t") - m.Click = readParam(r, "x-click", "click") + title := readParam(r, "x-title", "title", "t") + if title != "" { + m.Title = title + } + click := readParam(r, "x-click", "click") + if click != "" { + m.Click = click + } filename := readParam(r, "x-filename", "filename", "file", "f") attach := readParam(r, "x-attach", "attach", "a") if attach != "" || filename != "" { @@ -514,9 +520,11 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca if messageStr != "" { m.Message = messageStr } - m.Priority, err = util.ParsePriority(readParam(r, "x-priority", "priority", "prio", "p")) + priority, err := util.ParsePriority(readParam(r, "x-priority", "priority", "prio", "p")) if err != nil { return false, false, "", false, errHTTPBadRequestPriorityInvalid + } else if priority > 0 { + m.Priority = priority } tagsStr := readParam(r, "x-tags", "tags", "tag", "ta") if tagsStr != "" { @@ -895,6 +903,13 @@ func parseSince(r *http.Request, poll bool) (sinceMarker, error) { return sinceNoMessages, nil } + // ID/timestamp + parts := strings.Split(since, "/") + if len(parts) == 2 && validMessageID(parts[0]) && validUnixTimestamp(parts[1]) { + t, _ := toUnixTimestamp(parts[1]) + return newSince(parts[0], t), nil + } + // ID, timestamp, duration if validMessageID(since) { return newSinceID(since), nil diff --git a/server/server_firebase_test.go b/server/server_firebase_test.go index 1fdd8a6e..6abe0db3 100644 --- a/server/server_firebase_test.go +++ b/server/server_firebase_test.go @@ -77,6 +77,7 @@ func TestToFirebaseMessage_Message_Normal_Allowed(t *testing.T) { require.Equal(t, map[string]string{ "id": m.ID, "time": fmt.Sprintf("%d", m.Time), + "updated": "0", "event": "message", "topic": "mytopic", "priority": "4", diff --git a/server/server_test.go b/server/server_test.go index 0827cc90..449c8d37 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -390,6 +390,69 @@ func TestServer_PublishAndPollSince(t *testing.T) { require.Equal(t, 40008, toHTTPError(t, response.Body.String()).Code) } +func TestServer_PublishUpdateAndPollSince(t *testing.T) { + s := newTestServer(t, newTestConfig(t)) + + // Initial PUT + response := request(t, s, "PUT", "/mytopic?t=atitle&tags=tag1,tag2&prio=high&click=https://google.com&attach=https://heckel.io", "test 1", nil) + message1 := toMessage(t, response.Body.String()) + require.Equal(t, int64(0), message1.Updated) + require.Equal(t, "test 1", message1.Message) + require.Equal(t, "atitle", message1.Title) + require.Equal(t, 4, message1.Priority) + require.Equal(t, []string{"tag1", "tag2"}, message1.Tags) + require.Equal(t, "https://google.com", message1.Click) + require.Equal(t, "https://heckel.io", message1.Attachment.URL) + + // Update + response = request(t, s, "PUT", "/mytopic/"+message1.ID+"?prio=low", "test 2", nil) + message2 := toMessage(t, response.Body.String()) + require.Equal(t, message1.ID, message2.ID) + require.True(t, message2.Updated > message1.Updated) + require.Equal(t, "test 2", message2.Message) // Updated + require.Equal(t, "atitle", message2.Title) + require.Equal(t, 2, message2.Priority) // Updated + require.Equal(t, []string{"tag1", "tag2"}, message2.Tags) + require.Equal(t, "https://google.com", message2.Click) + require.Equal(t, "https://heckel.io", message2.Attachment.URL) + + time.Sleep(1100 * time.Millisecond) + + // Another update + response = request(t, s, "PUT", "/mytopic/"+message1.ID+"?title=new+title", "test 3", nil) + message3 := toMessage(t, response.Body.String()) + require.True(t, message3.Updated > message2.Updated) + require.Equal(t, "test 3", message3.Message) // Updated + require.Equal(t, "new title", message3.Title) // Updated + + // Get all messages: Should be only one that was updated + since := "all" + response = request(t, s, "GET", "/mytopic/json?since="+since+"&poll=1", "", nil) + messages := toMessages(t, response.Body.String()) + require.Equal(t, 1, len(messages)) + require.Equal(t, message1.ID, messages[0].ID) + require.Equal(t, "test 3", messages[0].Message) + + // Get all messages since "message ID": Should be zero, since we know this message + since = message1.ID + response = request(t, s, "GET", "/mytopic/json?since="+since+"&poll=1", "", nil) + messages = toMessages(t, response.Body.String()) + require.Equal(t, 0, len(messages)) + + // Get all messages since "message ID" but with an older timestamp: Should be the latest updated message + since = fmt.Sprintf("%s/%d", message1.ID, message2.Updated) // We're missing an update + response = request(t, s, "GET", "/mytopic/json?since="+since+"&poll=1", "", nil) + messages = toMessages(t, response.Body.String()) + require.Equal(t, 1, len(messages)) + require.Equal(t, "test 3", messages[0].Message) + + // Get all messages since "message ID" with the current timestamp: No messages expected + since = fmt.Sprintf("%s/%d", message3.ID, message3.Updated) // We are up-to-date + response = request(t, s, "GET", "/mytopic/json?since="+since+"&poll=1", "", nil) + messages = toMessages(t, response.Body.String()) + require.Equal(t, 0, len(messages)) +} + func TestServer_PublishViaGET(t *testing.T) { s := newTestServer(t, newTestConfig(t)) diff --git a/server/types.go b/server/types.go index 0ab357f8..55dc1458 100644 --- a/server/types.go +++ b/server/types.go @@ -1,8 +1,10 @@ package server import ( + "errors" "heckel.io/ntfy/util" "net/http" + "strconv" "time" ) @@ -92,11 +94,31 @@ func validMessageID(s string) bool { return util.ValidRandomString(s, messageIDLength) } +func validUnixTimestamp(s string) bool { + _, err := toUnixTimestamp(s) + return err == nil +} + +func toUnixTimestamp(s string) (int64, error) { + u, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return 0, err + } + if u < 1000000000 || u > 3000000000 { // I know. It's practical. So relax ... + return 0, errors.New("invalid unix date") + } + return u, nil +} + type sinceMarker struct { time time.Time id string } +func newSince(id string, timestamp int64) sinceMarker { + return sinceMarker{time.Unix(timestamp, 0), id} +} + func newSinceTime(timestamp int64) sinceMarker { return sinceMarker{time.Unix(timestamp, 0), ""} } @@ -117,6 +139,10 @@ func (t sinceMarker) IsID() bool { return t.id != "" } +func (t sinceMarker) IsTime() bool { + return t.time.Unix() > 0 +} + func (t sinceMarker) Time() time.Time { return t.time }