diff --git a/server/server.go b/server/server.go index a1d97290..0b3381c4 100644 --- a/server/server.go +++ b/server/server.go @@ -9,6 +9,12 @@ import ( "encoding/json" "errors" "fmt" + "github.com/emersion/go-smtp" + "github.com/gorilla/websocket" + "golang.org/x/sync/errgroup" + "heckel.io/ntfy/log" + "heckel.io/ntfy/user" + "heckel.io/ntfy/util" "io" "net" "net/http" @@ -24,13 +30,6 @@ import ( "sync" "time" "unicode/utf8" - - "github.com/emersion/go-smtp" - "github.com/gorilla/websocket" - "golang.org/x/sync/errgroup" - "heckel.io/ntfy/log" - "heckel.io/ntfy/user" - "heckel.io/ntfy/util" ) // Server is the main server, providing the UI and API for ntfy @@ -105,15 +104,15 @@ var ( ) const ( - firebaseControlTopic = "~control" // See Android if changed - firebasePollTopic = "~poll" // See iOS if changed - emptyMessageBody = "triggered" // Used if message body is empty - newMessageBody = "New message" // Used in poll requests as generic message - defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment - encodingBase64 = "base64" // Used mainly for binary UnifiedPush messages - jsonBodyBytesLimit = 16384 - subscriberBilledTopicPrefix = "up_" - subscriberBilledValidity = 12 * time.Hour + firebaseControlTopic = "~control" // See Android if changed + firebasePollTopic = "~poll" // See iOS if changed + emptyMessageBody = "triggered" // Used if message body is empty + newMessageBody = "New message" // Used in poll requests as generic message + defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment + encodingBase64 = "base64" // Used mainly for binary UnifiedPush messages + jsonBodyBytesLimit = 16384 + unifiedPushTopicPrefix = "up" // Temporarily, we rate limit all "up*" topics based on the subscriber + rateVisitorExpiryDuration = 12 * time.Hour ) // WebSocket constants @@ -996,7 +995,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v * defer cancel() subscriberIDs := make([]int, 0) for _, t := range topics { - subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, subscriberBilledTopicPrefix) // temporarily do prefix as well + subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) // temporarily do prefix as well subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel, subscriberRateLimited)) } defer func() { @@ -1129,7 +1128,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi } subscriberIDs := make([]int, 0) for _, t := range topics { - subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, subscriberBilledTopicPrefix) // temporarily do prefix as well + subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) // temporarily do prefix as well subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel, subscriberRateLimited)) } defer func() { @@ -1162,7 +1161,6 @@ func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, schedu if err != nil { return } - subscriberTopics = readCommaSeperatedParam(r, "subscriber-rate-limit-topics", "x-subscriber-rate-limit-topics", "srlt") return } diff --git a/server/server_manager.go b/server/server_manager.go index 114b7c85..6d306354 100644 --- a/server/server_manager.go +++ b/server/server_manager.go @@ -40,7 +40,7 @@ func (s *Server) execManager() { if ev.IsTrace() { expiryMessage := "" if subs == 0 { - expiryTime := time.Until(t.vRateExpires) + expiryTime := time.Until(t.rateVisitorExpires) expiryMessage = ", expires in " + expiryTime.String() } ev.Trace("- topic %s: %d subscribers%s", t.ID, subs, expiryMessage) diff --git a/server/server_middleware.go b/server/server_middleware.go index 1130b08e..1d2f734d 100644 --- a/server/server_middleware.go +++ b/server/server_middleware.go @@ -25,15 +25,15 @@ func (s *Server) limitRequestsWithTopic(next handleFunc) handleFunc { if err != nil { return err } - vRate := v + vrate := v if topicCountsAgainst := t.Billee(); topicCountsAgainst != nil { - vRate = topicCountsAgainst + vrate = topicCountsAgainst } - r = r.WithContext(context.WithValue(context.WithValue(r.Context(), "vRate", vRate), "topic", t)) + r = r.WithContext(context.WithValue(context.WithValue(r.Context(), "vRate", vrate), "topic", t)) if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) { return next(w, r, v) - } else if !vRate.RequestAllowed() { + } else if !vrate.RequestAllowed() { return errHTTPTooManyRequestsLimitRequests } return next(w, r, v) diff --git a/server/server_test.go b/server/server_test.go index 4abff399..fe5a49fc 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1889,6 +1889,49 @@ func TestServer_AnonymousUser_And_NonTierUser_Are_Same_Visitor(t *testing.T) { require.Equal(t, int64(2), account.Stats.Messages) } +func TestServer_SubscriberRateLimiting(t *testing.T) { + c := newTestConfigWithAuthFile(t) + c.VisitorRequestLimitBurst = 3 + s := newTestServer(t, c) + + subscriber1Fn := func(r *http.Request) { + r.RemoteAddr = "1.2.3.4" + } + rr := request(t, s, "GET", "/subscriber1topic/json?poll=1", "", map[string]string{ + "Subscriber-Rate-Limit-Topics": "mytopic1", + }, subscriber1Fn) + require.Equal(t, 200, rr.Code) + require.Equal(t, "", rr.Body.String()) + + subscriber2Fn := func(r *http.Request) { + r.RemoteAddr = "8.7.7.1" + } + rr = request(t, s, "GET", "/upSUB2topic/json?poll=1", "", nil, subscriber2Fn) + require.Equal(t, 200, rr.Code) + require.Equal(t, "", rr.Body.String()) + + for i := 0; i < 3; i++ { + rr := request(t, s, "PUT", "/subscriber1topic", "some message", nil) + require.Equal(t, 200, rr.Code) + } + rr = request(t, s, "PUT", "/subscriber1topic", "some message", nil) + require.Equal(t, 429, rr.Code) + + for i := 0; i < 3; i++ { + rr := request(t, s, "PUT", "/upSUB2topic", "some message", nil) + require.Equal(t, 200, rr.Code) // If we fail here, handlePublish is using the wrong visitor! + } + rr = request(t, s, "PUT", "/upSUB2topic", "some message", nil) + require.Equal(t, 429, rr.Code) + + for i := 0; i < 3; i++ { + rr := request(t, s, "PUT", "/some-other-topic", "some message", nil) + require.Equal(t, 200, rr.Code) + } + rr = request(t, s, "PUT", "/some-other-topic", "some message", nil) + require.Equal(t, 429, rr.Code) +} + func newTestConfig(t *testing.T) *Config { conf := NewConfig() conf.BaseURL = "http://127.0.0.1:12345" @@ -1914,7 +1957,7 @@ func newTestServer(t *testing.T, config *Config) *Server { return server } -func request(t *testing.T, s *Server, method, url, body string, headers map[string]string) *httptest.ResponseRecorder { +func request(t *testing.T, s *Server, method, url, body string, headers map[string]string, fn ...func(r *http.Request)) *httptest.ResponseRecorder { rr := httptest.NewRecorder() req, err := http.NewRequest(method, url, strings.NewReader(body)) if err != nil { @@ -1924,6 +1967,9 @@ func request(t *testing.T, s *Server, method, url, body string, headers map[stri for k, v := range headers { req.Header.Set(k, v) } + for _, f := range fn { + f(req) + } s.handle(rr, req) return rr } diff --git a/server/topic.go b/server/topic.go index 31545340..85af0058 100644 --- a/server/topic.go +++ b/server/topic.go @@ -11,11 +11,11 @@ import ( // topic represents a channel to which subscribers can subscribe, and publishers // can publish a message type topic struct { - ID string - subscribers map[int]*topicSubscriber - vRate *visitor - vRateExpires time.Time - mu sync.Mutex + ID string + subscribers map[int]*topicSubscriber + rateVisitor *visitor + rateVisitorExpires time.Time + mu sync.RWMutex } type topicSubscriber struct { @@ -49,9 +49,9 @@ func (t *topic) Subscribe(s subscriber, visitor *visitor, cancel func(), subscri } // if no subscriber is already handling the rate limit - if t.vRate == nil && subscriberRateLimit { - t.vRate = visitor - t.vRateExpires = time.Time{} + if t.rateVisitor == nil && subscriberRateLimit { + t.rateVisitor = visitor + t.rateVisitorExpires = time.Time{} } return subscriberID @@ -61,16 +61,16 @@ func (t *topic) Stale() bool { t.mu.Lock() defer t.mu.Unlock() // if Time is initialized (not the zero value) and the expiry time has passed - if !t.vRateExpires.IsZero() && t.vRateExpires.Before(time.Now()) { - t.vRate = nil + if !t.rateVisitorExpires.IsZero() && t.rateVisitorExpires.Before(time.Now()) { + t.rateVisitor = nil } - return len(t.subscribers) == 0 && t.vRate == nil + return len(t.subscribers) == 0 && t.rateVisitor == nil } func (t *topic) Billee() *visitor { - t.mu.Lock() - defer t.mu.Unlock() - return t.vRate + t.mu.RLock() + defer t.mu.RUnlock() + return t.rateVisitor } // Unsubscribe removes the subscription from the list of subscribers @@ -84,16 +84,16 @@ func (t *topic) Unsubscribe(id int) { // look for an active subscriber (in random order) that wants to handle the rate limit for _, v := range t.subscribers { if v.subscriberRateLimit { - t.vRate = v.visitor - t.vRateExpires = time.Time{} + t.rateVisitor = v.visitor + t.rateVisitorExpires = time.Time{} return } } // if no active subscriber is found, count it towards the leaving subscriber if deletingSub.subscriberRateLimit { - t.vRate = deletingSub.visitor - t.vRateExpires = time.Now().Add(subscriberBilledValidity) + t.rateVisitor = deletingSub.visitor + t.rateVisitorExpires = time.Now().Add(rateVisitorExpiryDuration) } } @@ -123,8 +123,8 @@ func (t *topic) Publish(v *visitor, m *message) error { // SubscribersCount returns the number of subscribers to this topic func (t *topic) SubscribersCount() int { - t.mu.Lock() - defer t.mu.Unlock() + t.mu.RLock() + defer t.mu.RUnlock() return len(t.subscribers) }