diff --git a/server/errors.go b/server/errors.go index 32c1b3b9..97af5472 100644 --- a/server/errors.go +++ b/server/errors.go @@ -50,7 +50,9 @@ var ( errHTTPBadRequestWebSocketsUpgradeHeaderMissing = &errHTTP{40016, http.StatusBadRequest, "invalid request: client not using the websocket protocol", "https://ntfy.sh/docs/subscribe/api/#websockets"} errHTTPBadRequestJSONInvalid = &errHTTP{40017, http.StatusBadRequest, "invalid request: request body must be message JSON", "https://ntfy.sh/docs/publish/#publish-as-json"} errHTTPBadRequestActionsInvalid = &errHTTP{40018, http.StatusBadRequest, "invalid request: actions invalid", "https://ntfy.sh/docs/publish/#action-buttons"} + errHTTPBadRequestDelayExpected = &errHTTP{40019, http.StatusBadRequest, "invalid request: expected delay in request, but none found", ""} errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", ""} + errHTTPNotFoundMessageDoesNotExist = &errHTTP{40402, http.StatusNotFound, "message not found", ""} errHTTPUnauthorized = &errHTTP{40101, http.StatusUnauthorized, "unauthorized", "https://ntfy.sh/docs/publish/#authentication"} errHTTPForbidden = &errHTTP{40301, http.StatusForbidden, "forbidden", "https://ntfy.sh/docs/publish/#authentication"} errHTTPEntityTooLargeAttachmentTooLarge = &errHTTP{41301, http.StatusRequestEntityTooLarge, "attachment too large, or bandwidth limit reached", "https://ntfy.sh/docs/publish/#limitations"} diff --git a/server/message_cache.go b/server/message_cache.go index b55c34ba..bdc36c4c 100644 --- a/server/message_cache.go +++ b/server/message_cache.go @@ -48,6 +48,11 @@ const ( INSERT INTO messages (mid, time, topic, message, title, priority, tags, click, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_owner, encoding, published) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ` + updateMessageQuery = ` + UPDATE messages + SET time = ? + WHERE topic = ? AND mid = ? AND published = 0 + ` pruneMessagesQuery = `DELETE FROM messages WHERE time < ? AND published = 1` selectRowIDFromMessageID = `SELECT id FROM messages WHERE topic = ? AND mid = ?` selectMessagesSinceTimeQuery = ` @@ -80,6 +85,11 @@ const ( WHERE time <= ? AND published = 0 ORDER BY time, id ` + selectMessagesScheduledByTagOrID = ` + SELECT mid, time, topic, message, title, priority, tags, click, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_owner, encoding + FROM messages + WHERE topic = ? AND (tags LIKE ? OR mid = ?) AND published = 0 + ` updateMessagePublishedQuery = `UPDATE messages SET published = 1 WHERE mid = ?` selectMessagesCountQuery = `SELECT COUNT(*) FROM messages` selectMessageCountForTopicQuery = `SELECT COUNT(*) FROM messages WHERE topic = ?` @@ -219,8 +229,7 @@ func createMemoryFilename() string { func (c *messageCache) AddMessage(m *message) error { if m.Event != messageEvent { return errUnexpectedMessageType - } - if c.nop { + } else if c.nop { return nil } published := m.Time <= time.Now().Unix() @@ -266,6 +275,21 @@ func (c *messageCache) AddMessage(m *message) error { return err } +func (c *messageCache) UpdateMessage(m *message) error { + if m.Event != messageEvent { + return errUnexpectedMessageType + } else if c.nop { + return nil + } + _, err := c.db.Exec( + updateMessageQuery, + m.Time, + m.Topic, + m.ID, + ) + return err +} + func (c *messageCache) Messages(topic string, since sinceMarker, scheduled bool) ([]*message, error) { if since.IsNone() { return make([]*message, 0), nil @@ -409,6 +433,24 @@ func (c *messageCache) AttachmentsExpired() ([]string, error) { return ids, nil } +func (c *messageCache) MessagesScheduledByTagOrID(topic, selector string) ([]*message, error) { + rows, err := c.db.Query(selectMessagesScheduledByTagOrID, topic, "%"+selector+"%", selector) // Ugly string matching search first, later match exactly + if err != nil { + return nil, err + } + maybeMatchingMessages, err := readMessages(rows) + if err != nil { + return nil, err + } + messages := make([]*message, 0) + for _, m := range maybeMatchingMessages { + if util.InStringList(m.Tags, selector) || m.ID == selector { + messages = append(messages, m) + } + } + return messages, nil +} + func readMessages(rows *sql.Rows) ([]*message, error) { defer rows.Close() messages := make([]*message, 0) diff --git a/server/server.go b/server/server.go index 1a643c23..738fba0b 100644 --- a/server/server.go +++ b/server/server.go @@ -56,8 +56,9 @@ type handleFunc func(http.ResponseWriter, *http.Request, *visitor) error var ( // If changed, don't forget to update Android App and auth_sqlite.go - topicRegex = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`) // No /! - topicPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}$`) // Regex must match JS & Android app! + topicRegex = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`) // No /! + topicPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}$`) // Regex must match JS & Android app! + updatePathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}/[^/]+$`) externalTopicPathRegex = regexp.MustCompile(`^/[^/]+\.[^/]+/[-_A-Za-z0-9]{1,64}$`) // Extended topic path, for web-app, e.g. /example.com/mytopic jsonPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/json$`) ssePathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/sse$`) @@ -287,6 +288,8 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit return s.limitRequests(s.authWrite(s.handlePublish))(w, r, v) } else if r.Method == http.MethodGet && publishPathRegex.MatchString(r.URL.Path) { return s.limitRequests(s.authWrite(s.handlePublish))(w, r, v) + } else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && updatePathRegex.MatchString(r.URL.Path) { + return s.limitRequests(s.authWrite(s.handleUpdate))(w, r, v) } else if r.Method == http.MethodGet && jsonPathRegex.MatchString(r.URL.Path) { return s.limitRequests(s.authRead(s.handleSubscribeJSON))(w, r, v) } else if r.Method == http.MethodGet && ssePathRegex.MatchString(r.URL.Path) { @@ -518,7 +521,7 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca m.Tags = append(m.Tags, strings.TrimSpace(s)) } } - delayStr := readParam(r, "x-delay", "delay", "x-at", "at", "x-in", "in") + delayStr := readDelayParam(r) if delayStr != "" { if !cache { return false, false, "", false, errHTTPBadRequestDelayNoCache @@ -526,15 +529,11 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca if email != "" { return false, false, "", false, errHTTPBadRequestDelayNoEmail // we cannot store the email address (yet) } - delay, err := util.ParseFutureTime(delayStr, time.Now()) + futureTime, err := s.parseDelay(delayStr) if err != nil { - return false, false, "", false, errHTTPBadRequestDelayCannotParse - } else if delay.Unix() < time.Now().Add(s.config.MinDelay).Unix() { - return false, false, "", false, errHTTPBadRequestDelayTooSmall - } else if delay.Unix() > time.Now().Add(s.config.MaxDelay).Unix() { - return false, false, "", false, errHTTPBadRequestDelayTooLarge + return false, false, "", false, err } - m.Time = delay.Unix() + m.Time = futureTime } actionsStr := readParam(r, "x-actions", "actions", "action") if actionsStr != "" { @@ -551,6 +550,22 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca return cache, firebase, email, unifiedpush, nil } +func readDelayParam(r *http.Request) string { + return readParam(r, "x-delay", "delay", "x-at", "at", "x-in", "in") +} + +func (s *Server) parseDelay(delayStr string) (int64, error) { + futureTime, err := util.ParseFutureTime(delayStr, time.Now()) + if err != nil { + return 0, errHTTPBadRequestDelayCannotParse + } else if futureTime.Unix() < time.Now().Add(s.config.MinDelay).Unix() { + return 0, errHTTPBadRequestDelayTooSmall + } else if futureTime.Unix() > time.Now().Add(s.config.MaxDelay).Unix() { + return 0, errHTTPBadRequestDelayTooLarge + } + return futureTime.Unix(), nil +} + // handlePublishBody consumes the PUT/POST body and decides whether the body is an attachment or the message. // // 1. curl -T somebinarydata.bin "ntfy.sh/mytopic?up=1" @@ -639,6 +654,46 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message, return nil } +func (s *Server) handleUpdate(w http.ResponseWriter, r *http.Request, v *visitor) error { + // Parse updatable params + parts := strings.Split(r.URL.Path, "/") + if len(parts) < 3 { + return errHTTPBadRequestTopicInvalid + } + t := parts[1] + selector := parts[2] + delayStr := readDelayParam(r) + if delayStr == "" { + return errHTTPBadRequestDelayExpected + } + futureTime, err := s.parseDelay(delayStr) + if err != nil { + return err + } + + // Update matching message(s) and print them + messages, err := s.messageCache.MessagesScheduledByTagOrID(t, selector) + if err != nil { + return err + } else if len(messages) == 0 { + return s.handlePublish(w, r, v) // If no messages found, publish a new one! + } + for _, m := range messages { + m.Time = futureTime + if err := s.messageCache.UpdateMessage(m); err != nil { + return err + } + } + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests + for _, m := range messages { + if err := json.NewEncoder(w).Encode(m); err != nil { + return err + } + } + return nil +} + func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v *visitor) error { encoder := func(msg *message) (string, error) { var buf bytes.Buffer