diff --git a/server/server.go b/server/server.go index 0b3381c4..ec4ca670 100644 --- a/server/server.go +++ b/server/server.go @@ -571,7 +571,7 @@ func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error { } func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*message, error) { - vRate, ok := r.Context().Value("vRate").(*visitor) + vrate, ok := r.Context().Value("vRate").(*visitor) if !ok { return nil, errHTTPInternalError } @@ -579,8 +579,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes if !ok { return nil, errHTTPInternalError } - - if !vRate.MessageAllowed() { + if !vrate.MessageAllowed() { return nil, errHTTPTooManyRequestsLimitMessages } body, err := util.Peek(r.Body, s.config.MessageLimit) @@ -588,7 +587,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes return nil, err } m := newDefaultMessage(t.ID, "") - cache, firebase, email, unifiedpush, err := s.parsePublishParams(r, vRate, m) + cache, firebase, email, unifiedpush, err := s.parsePublishParams(r, vrate, m) if err != nil { return nil, err } @@ -607,7 +606,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes m.Message = emptyMessageBody } delayed := m.Time > time.Now().Unix() - ev := logvrm(vRate, r, m). + ev := logvrm(vrate, r, m). Tag(tagPublish). Fields(log.Context{ "message_delayed": delayed, @@ -625,7 +624,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes return nil, err } if s.firebaseClient != nil && firebase { - go s.sendToFirebase(vRate, m) + go s.sendToFirebase(vrate, m) } if s.smtpSender != nil && email != "" { go s.sendEmail(v, m, email) @@ -657,7 +656,6 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito if err != nil { return err } - return s.writeJSON(w, m) } @@ -766,7 +764,7 @@ func (s *Server) parsePublishParams(r *http.Request, vRate *visitor, m *message) if err != nil { return false, false, "", false, errHTTPBadRequestPriorityInvalid } - m.Tags = readCommaSeperatedParam(r, "x-tags", "tags", "tag", "ta") + m.Tags = readCommaSeparatedParam(r, "x-tags", "tags", "tag", "ta") delayStr := readParam(r, "x-delay", "delay", "x-at", "at", "x-in", "in") if delayStr != "" { if !cache { @@ -986,6 +984,12 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v * } return nil } + for _, t := range topics { + subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) // temporarily do prefix as well + if subscriberRateLimited { + t.SetRateVisitor(v) + } + } w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset! if poll { @@ -995,8 +999,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v * defer cancel() subscriberIDs := make([]int, 0) for _, t := range topics { - subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) // temporarily do prefix as well - subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel, subscriberRateLimited)) + subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel)) } defer func() { for i, subscriberID := range subscriberIDs { @@ -1122,14 +1125,19 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi } return conn.WriteJSON(msg) } + for _, t := range topics { + subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) // temporarily do prefix as well + if subscriberRateLimited { + t.SetRateVisitor(v) + } + } w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests if poll { return s.sendOldMessages(topics, since, scheduled, v, sub) } subscriberIDs := make([]int, 0) for _, t := range topics { - subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) // temporarily do prefix as well - subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel, subscriberRateLimited)) + subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel)) } defer func() { for i, subscriberID := range subscriberIDs { @@ -1161,7 +1169,7 @@ func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, schedu if err != nil { return } - subscriberTopics = readCommaSeperatedParam(r, "subscriber-rate-limit-topics", "x-subscriber-rate-limit-topics", "srlt") + subscriberTopics = readCommaSeparatedParam(r, "subscriber-rate-limit-topics", "x-subscriber-rate-limit-topics", "srlt") return } diff --git a/server/server_middleware.go b/server/server_middleware.go index 1d2f734d..712223c4 100644 --- a/server/server_middleware.go +++ b/server/server_middleware.go @@ -26,8 +26,8 @@ func (s *Server) limitRequestsWithTopic(next handleFunc) handleFunc { return err } vrate := v - if topicCountsAgainst := t.Billee(); topicCountsAgainst != nil { - vrate = topicCountsAgainst + if rateVisitor := t.RateVisitor(); rateVisitor != nil { + vrate = rateVisitor } r = r.WithContext(context.WithValue(context.WithValue(r.Context(), "vRate", vrate), "topic", t)) diff --git a/server/server_test.go b/server/server_test.go index fe5a49fc..7f2665a0 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1894,15 +1894,17 @@ func TestServer_SubscriberRateLimiting(t *testing.T) { c.VisitorRequestLimitBurst = 3 s := newTestServer(t, c) + // "Register" visitor 1.2.3.4 to topic "subscriber1topic" as a rate limit visitor subscriber1Fn := func(r *http.Request) { r.RemoteAddr = "1.2.3.4" } rr := request(t, s, "GET", "/subscriber1topic/json?poll=1", "", map[string]string{ - "Subscriber-Rate-Limit-Topics": "mytopic1", + "Subscriber-Rate-Limit-Topics": "subscriber1topic", }, subscriber1Fn) require.Equal(t, 200, rr.Code) require.Equal(t, "", rr.Body.String()) + // "Register" visitor 8.7.7.1 to topic "upSUB2topic" as a rate limit visitor (implicitly via topic name) subscriber2Fn := func(r *http.Request) { r.RemoteAddr = "8.7.7.1" } @@ -1910,20 +1912,28 @@ func TestServer_SubscriberRateLimiting(t *testing.T) { require.Equal(t, 200, rr.Code) require.Equal(t, "", rr.Body.String()) - for i := 0; i < 3; i++ { + // Publish 2 messages to "subscriber1topic" as visitor 9.9.9.9. It'd be 3 normally, but the + // GET request before is also counted towards the request limiter. + for i := 0; i < 2; i++ { rr := request(t, s, "PUT", "/subscriber1topic", "some message", nil) require.Equal(t, 200, rr.Code) } rr = request(t, s, "PUT", "/subscriber1topic", "some message", nil) require.Equal(t, 429, rr.Code) - for i := 0; i < 3; i++ { + // Publish another 2 messages to "upSUB2topic" as visitor 9.9.9.9 + for i := 0; i < 2; i++ { rr := request(t, s, "PUT", "/upSUB2topic", "some message", nil) require.Equal(t, 200, rr.Code) // If we fail here, handlePublish is using the wrong visitor! } rr = request(t, s, "PUT", "/upSUB2topic", "some message", nil) require.Equal(t, 429, rr.Code) + // Hurray! At this point, visitor 9.9.9.9 has published 4 messages, even though + // VisitorRequestLimitBurst is 3. That means it's working. + + // Now let's confirm that so far we haven't used up any of visitor 9.9.9.9's request limiter + // by publishing another 3 requests from it. for i := 0; i < 3; i++ { rr := request(t, s, "PUT", "/some-other-topic", "some message", nil) require.Equal(t, 200, rr.Code) @@ -1959,18 +1969,18 @@ func newTestServer(t *testing.T, config *Config) *Server { func request(t *testing.T, s *Server, method, url, body string, headers map[string]string, fn ...func(r *http.Request)) *httptest.ResponseRecorder { rr := httptest.NewRecorder() - req, err := http.NewRequest(method, url, strings.NewReader(body)) + r, err := http.NewRequest(method, url, strings.NewReader(body)) if err != nil { t.Fatal(err) } - req.RemoteAddr = "9.9.9.9" // Used for tests + r.RemoteAddr = "9.9.9.9" // Used for tests for k, v := range headers { - req.Header.Set(k, v) + r.Header.Set(k, v) } for _, f := range fn { - f(req) + f(r) } - s.handle(rr, req) + s.handle(rr, r) return rr } diff --git a/server/topic.go b/server/topic.go index 85af0058..3b0cb542 100644 --- a/server/topic.go +++ b/server/topic.go @@ -19,10 +19,9 @@ type topic struct { } type topicSubscriber struct { - subscriber subscriber - visitor *visitor // User ID associated with this subscription, may be empty - cancel func() - subscriberRateLimit bool + subscriber subscriber + visitor *visitor // User ID associated with this subscription, may be empty + cancel func() } // subscriber is a function that is called for every new message on a topic @@ -37,39 +36,40 @@ func newTopic(id string) *topic { } // Subscribe subscribes to this topic -func (t *topic) Subscribe(s subscriber, visitor *visitor, cancel func(), subscriberRateLimit bool) int { +func (t *topic) Subscribe(s subscriber, visitor *visitor, cancel func()) int { t.mu.Lock() defer t.mu.Unlock() subscriberID := rand.Int() t.subscribers[subscriberID] = &topicSubscriber{ - visitor: visitor, // May be empty - subscriber: s, - cancel: cancel, - subscriberRateLimit: subscriberRateLimit, + visitor: visitor, // May be empty + subscriber: s, + cancel: cancel, } - - // if no subscriber is already handling the rate limit - if t.rateVisitor == nil && subscriberRateLimit { - t.rateVisitor = visitor - t.rateVisitorExpires = time.Time{} - } - return subscriberID } func (t *topic) Stale() bool { t.mu.Lock() defer t.mu.Unlock() - // if Time is initialized (not the zero value) and the expiry time has passed - if !t.rateVisitorExpires.IsZero() && t.rateVisitorExpires.Before(time.Now()) { + if t.rateVisitorExpires.Before(time.Now()) { t.rateVisitor = nil } return len(t.subscribers) == 0 && t.rateVisitor == nil } -func (t *topic) Billee() *visitor { - t.mu.RLock() - defer t.mu.RUnlock() +func (t *topic) SetRateVisitor(v *visitor) { + t.mu.Lock() + defer t.mu.Unlock() + t.rateVisitor = v + t.rateVisitorExpires = time.Now().Add(rateVisitorExpiryDuration) +} + +func (t *topic) RateVisitor() *visitor { + t.mu.Lock() + defer t.mu.Unlock() + if t.rateVisitorExpires.Before(time.Now()) { + t.rateVisitor = nil + } return t.rateVisitor } @@ -77,24 +77,7 @@ func (t *topic) Billee() *visitor { func (t *topic) Unsubscribe(id int) { t.mu.Lock() defer t.mu.Unlock() - - deletingSub := t.subscribers[id] delete(t.subscribers, id) - - // look for an active subscriber (in random order) that wants to handle the rate limit - for _, v := range t.subscribers { - if v.subscriberRateLimit { - t.rateVisitor = v.visitor - t.rateVisitorExpires = time.Time{} - return - } - } - - // if no active subscriber is found, count it towards the leaving subscriber - if deletingSub.subscriberRateLimit { - t.rateVisitor = deletingSub.visitor - t.rateVisitorExpires = time.Now().Add(rateVisitorExpiryDuration) - } } // Publish asynchronously publishes to all subscribers diff --git a/server/util.go b/server/util.go index 26d08543..8ec258fc 100644 --- a/server/util.go +++ b/server/util.go @@ -16,7 +16,7 @@ func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool { return value == "1" || value == "yes" || value == "true" } -func readCommaSeperatedParam(r *http.Request, names ...string) (params []string) { +func readCommaSeparatedParam(r *http.Request, names ...string) (params []string) { paramStr := readParam(r, names...) if paramStr != "" { params = make([]string, 0)