diff --git a/server/errors.go b/server/errors.go index f50a4df9..819f972e 100644 --- a/server/errors.go +++ b/server/errors.go @@ -3,8 +3,9 @@ package server import ( "encoding/json" "fmt" - "heckel.io/ntfy/log" "net/http" + + "heckel.io/ntfy/log" ) // errHTTP is a generic HTTP error for any non-200 HTTP error @@ -92,5 +93,4 @@ var ( errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""} errHTTPInternalErrorInvalidPath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid path", ""} errHTTPInternalErrorMissingBaseURL = &errHTTP{50003, http.StatusInternalServerError, "internal server error: base-url must be be configured for this feature", "https://ntfy.sh/docs/config/"} - errHTTPWontStoreMessage = &errHTTP{50701, http.StatusInsufficientStorage, "topic is inactive; no device available to recieve message", ""} ) diff --git a/server/server.go b/server/server.go index aa261d5e..a5307964 100644 --- a/server/server.go +++ b/server/server.go @@ -9,12 +9,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/emersion/go-smtp" - "github.com/gorilla/websocket" - "golang.org/x/sync/errgroup" - "heckel.io/ntfy/log" - "heckel.io/ntfy/user" - "heckel.io/ntfy/util" "io" "net" "net/http" @@ -30,6 +24,13 @@ import ( "sync" "time" "unicode/utf8" + + "github.com/emersion/go-smtp" + "github.com/gorilla/websocket" + "golang.org/x/sync/errgroup" + "heckel.io/ntfy/log" + "heckel.io/ntfy/user" + "heckel.io/ntfy/util" ) /* @@ -605,23 +606,23 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes if err != nil { return nil, err } - v_old := v - if strings.HasPrefix(t.ID, subscriberBilledTopicPrefix) { - v = t.getBillee() - if v == nil { - return nil, errHTTPWontStoreMessage - } + vRate := v + if topicCountsAgainst := t.Billee(); topicCountsAgainst != nil { + vRate = topicCountsAgainst } - if !v.MessageAllowed() { - return nil, errHTTPTooManyRequestsLimitMessages + if !vRate.MessageAllowed() { + vRate = v + if !v.MessageAllowed() { + return nil, errHTTPTooManyRequestsLimitMessages + } } body, err := util.Peek(r.Body, s.config.MessageLimit) if err != nil { return nil, err } m := newDefaultMessage(t.ID, "") - cache, firebase, email, unifiedpush, err := s.parsePublishParams(r, v, m) + cache, firebase, email, unifiedpush, err := s.parsePublishParams(r, vRate, m) if err != nil { return nil, err } @@ -630,7 +631,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes } m.Sender = v.IP() m.User = v.MaybeUserID() - m.Expires = time.Unix(m.Time, 0).Add(v.Limits().MessageExpiryDuration).Unix() + m.Expires = time.Unix(m.Time, 0).Add(vRate.Limits().MessageExpiryDuration).Unix() if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil { return nil, err } @@ -638,18 +639,18 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes m.Message = emptyMessageBody } delayed := m.Time > time.Now().Unix() - logvrm(v, r, m). + logvrm(vRate, r, m). Tag(tagPublish). Fields(log.Context{ - "message_delayed": delayed, - "message_firebase": firebase, - "message_unifiedpush": unifiedpush, - "message_email": email, + "message_delayed": delayed, + "message_firebase": firebase, + "message_unifiedpush": unifiedpush, + "message_email": email, + "message_subscriber_rate_limited": vRate != v, }). Debug("Received message") - //Where should I log the original visitor vs the billing visitor if log.IsTrace() { - logvrm(v_old, r, m). + logvrm(vRate, r, m). Tag(tagPublish). Field("message_body", util.MaybeMarshalJSON(m)). Trace("Message body") @@ -659,7 +660,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes return nil, err } if s.firebaseClient != nil && firebase { - go s.sendToFirebase(v, m) + go s.sendToFirebase(vRate, m) } if s.smtpSender != nil && email != "" { go s.sendEmail(v, m, email) @@ -745,7 +746,7 @@ func (s *Server) forwardPollRequest(v *visitor, m *message) { } } -func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) { +func (s *Server) parsePublishParams(r *http.Request, vRate *visitor, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) { cache = readBoolParam(r, true, "x-cache", "cache") firebase = readBoolParam(r, true, "x-firebase", "firebase") m.Title = readParam(r, "x-title", "title", "t") @@ -785,7 +786,7 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca } email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e") if email != "" { - if !v.EmailAllowed() { + if !vRate.EmailAllowed() { return false, false, "", false, errHTTPTooManyRequestsLimitEmails } } @@ -800,13 +801,7 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca if err != nil { return false, false, "", false, errHTTPBadRequestPriorityInvalid } - tagsStr := readParam(r, "x-tags", "tags", "tag", "ta") - if tagsStr != "" { - m.Tags = make([]string, 0) - for _, s := range util.SplitNoEmpty(tagsStr, ",") { - m.Tags = append(m.Tags, strings.TrimSpace(s)) - } - } + m.Tags = readCommaSeperatedParam(r, "x-tags", "tags", "tag", "ta") delayStr := readParam(r, "x-delay", "delay", "x-at", "at", "x-in", "in") if delayStr != "" { if !cache { @@ -996,7 +991,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v * if err != nil { return err } - poll, since, scheduled, filters, err := parseSubscribeParams(r) + poll, since, scheduled, filters, subscriberRateTopics, err := parseSubscribeParams(r) if err != nil { return err } @@ -1035,7 +1030,8 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v * defer cancel() subscriberIDs := make([]int, 0) for _, t := range topics { - subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel)) + subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, subscriberBilledTopicPrefix) // temporarily do prefix as well + subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel, subscriberRateLimited)) } defer func() { for i, subscriberID := range subscriberIDs { @@ -1078,7 +1074,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi if err != nil { return err } - poll, since, scheduled, filters, err := parseSubscribeParams(r) + poll, since, scheduled, filters, subscriberRateTopics, err := parseSubscribeParams(r) if err != nil { return err } @@ -1167,7 +1163,8 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi } subscriberIDs := make([]int, 0) for _, t := range topics { - subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel)) + subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, subscriberBilledTopicPrefix) // temporarily do prefix as well + subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel, subscriberRateLimited)) } defer func() { for i, subscriberID := range subscriberIDs { @@ -1188,7 +1185,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi return err } -func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, scheduled bool, filters *queryFilter, err error) { +func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, scheduled bool, filters *queryFilter, subscriberTopics []string, err error) { poll = readBoolParam(r, false, "x-poll", "poll", "po") scheduled = readBoolParam(r, false, "x-scheduled", "scheduled", "sched") since, err = parseSince(r, poll) @@ -1199,6 +1196,8 @@ func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, schedu if err != nil { return } + + subscriberTopics = readCommaSeperatedParam(r, "subscriber-rate-limit-topics", "x-subscriber-rate-limit-topics", "srlt") return } diff --git a/server/topic.go b/server/topic.go index 6911ec0b..19501a87 100644 --- a/server/topic.go +++ b/server/topic.go @@ -19,9 +19,10 @@ type topic struct { } type topicSubscriber struct { - subscriber subscriber - visitor *visitor // User ID associated with this subscription, may be empty - cancel func() + subscriber subscriber + visitor *visitor // User ID associated with this subscription, may be empty + cancel func() + subscriberRateLimit bool } // subscriber is a function that is called for every new message on a topic @@ -36,31 +37,36 @@ func newTopic(id string) *topic { } // Subscribe subscribes to this topic -func (t *topic) Subscribe(s subscriber, visitor *visitor, cancel func()) int { +func (t *topic) Subscribe(s subscriber, visitor *visitor, cancel func(), subscriberRateLimit bool) int { t.mu.Lock() defer t.mu.Unlock() subscriberID := rand.Int() t.subscribers[subscriberID] = &topicSubscriber{ - visitor: visitor, // May be empty - subscriber: s, - cancel: cancel, + visitor: visitor, // May be empty + subscriber: s, + cancel: cancel, + subscriberRateLimit: subscriberRateLimit, } + + // if no subscriber is already handling the rate limit + if t.lastVisitor == nil && subscriberRateLimit { + t.lastVisitor = visitor + t.lastVisitorExpires = time.Time{} + } + return subscriberID } func (t *topic) Stale() bool { - return t.getBillee() == nil -} - -func (t *topic) getBillee() *visitor { - for _, this_subscriber := range t.subscribers { - return this_subscriber.visitor - } - if t.lastVisitor != nil && t.lastVisitorExpires.After(time.Now()) { + // if Time is initialized (not the zero value) and the expiry time has passed + if !t.lastVisitorExpires.IsZero() && t.lastVisitorExpires.Before(time.Now()) { t.lastVisitor = nil } - return t.lastVisitor + return len(t.subscribers) == 0 && t.lastVisitor == nil +} +func (t *topic) Billee() *visitor { + return t.lastVisitor } // Unsubscribe removes the subscription from the list of subscribers @@ -68,11 +74,23 @@ func (t *topic) Unsubscribe(id int) { t.mu.Lock() defer t.mu.Unlock() - if len(t.subscribers) == 1 { - t.lastVisitor = t.subscribers[id].visitor + deletingSub := t.subscribers[id] + delete(t.subscribers, id) + + // look for an active subscriber (in random order) that wants to handle the rate limit + for _, v := range t.subscribers { + if v.subscriberRateLimit { + t.lastVisitor = v.visitor + t.lastVisitorExpires = time.Time{} + return + } + } + + // if no active subscriber is found, count it towards the leaving subscriber + if deletingSub.subscriberRateLimit { + t.lastVisitor = deletingSub.visitor t.lastVisitorExpires = time.Now().Add(subscriberBilledValidity) } - delete(t.subscribers, id) } // Publish asynchronously publishes to all subscribers diff --git a/server/util.go b/server/util.go index 8fbfaefa..048e2f93 100644 --- a/server/util.go +++ b/server/util.go @@ -1,12 +1,13 @@ package server import ( - "heckel.io/ntfy/log" - "heckel.io/ntfy/util" "io" "net/http" "net/netip" "strings" + + "heckel.io/ntfy/log" + "heckel.io/ntfy/util" ) func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool { @@ -17,6 +18,17 @@ func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool { return value == "1" || value == "yes" || value == "true" } +func readCommaSeperatedParam(r *http.Request, names ...string) (params []string) { + paramStr := readParam(r, names...) + if paramStr != "" { + params = make([]string, 0) + for _, s := range util.SplitNoEmpty(paramStr, ",") { + params = append(params, strings.TrimSpace(s)) + } + } + return params +} + func readParam(r *http.Request, names ...string) string { value := readHeaderParam(r, names...) if value != "" { @@ -35,6 +47,13 @@ func readHeaderParam(r *http.Request, names ...string) string { return "" } +func readHeaderParamValues(r *http.Request, names ...string) (values []string) { + for _, name := range names { + values = append(values, r.Header.Values(name)...) + } + return +} + func readQueryParam(r *http.Request, names ...string) string { for _, name := range names { value := r.URL.Query().Get(strings.ToLower(name))