From fa7a45902fdc22268952a5fc85d32966e71a1091 Mon Sep 17 00:00:00 2001 From: Philipp Heckel Date: Mon, 1 Nov 2021 15:21:38 -0400 Subject: [PATCH] Subscription limit --- config/config.go | 15 ++++++----- server/server.go | 54 +++++++++++++++++---------------------- server/visitor.go | 65 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 97 insertions(+), 37 deletions(-) create mode 100644 server/visitor.go diff --git a/config/config.go b/config/config.go index 6e3f242a..2d3db1c7 100644 --- a/config/config.go +++ b/config/config.go @@ -17,8 +17,9 @@ const ( // Defines the max number of requests, here: // 50 requests bucket, replenished at a rate of 1 per second var ( - defaultLimit = rate.Every(time.Second) - defaultLimitBurst = 50 + defaultRequestLimit = rate.Every(time.Second) + defaultRequestLimitBurst = 50 + defaultSubscriptionLimit = 30 // per visitor ) // Config is the main config struct for the application. Use New to instantiate a default config struct. @@ -28,8 +29,9 @@ type Config struct { MessageBufferDuration time.Duration KeepaliveInterval time.Duration ManagerInterval time.Duration - Limit rate.Limit - LimitBurst int + RequestLimit rate.Limit + RequestLimitBurst int + SubscriptionLimit int } // New instantiates a default new config @@ -40,7 +42,8 @@ func New(listenHTTP string) *Config { MessageBufferDuration: DefaultMessageBufferDuration, KeepaliveInterval: DefaultKeepaliveInterval, ManagerInterval: DefaultManagerInterval, - Limit: defaultLimit, - LimitBurst: defaultLimitBurst, + RequestLimit: defaultRequestLimit, + RequestLimitBurst: defaultRequestLimitBurst, + SubscriptionLimit: defaultSubscriptionLimit, } } diff --git a/server/server.go b/server/server.go index 907c0da1..6811a62e 100644 --- a/server/server.go +++ b/server/server.go @@ -9,7 +9,6 @@ import ( firebase "firebase.google.com/go" "firebase.google.com/go/messaging" "fmt" - "golang.org/x/time/rate" "google.golang.org/api/option" "heckel.io/ntfy/config" "io" @@ -23,9 +22,8 @@ import ( "time" ) -// TODO add "max connections open" limit // TODO add "max messages in a topic" limit -// TODO add "max topics" limit +// TODO implement persistence // Server is the main server type Server struct { @@ -37,12 +35,6 @@ type Server struct { mu sync.Mutex } -// visitor represents an API user, and its associated rate.Limiter used for rate limiting -type visitor struct { - limiter *rate.Limiter - seen time.Time -} - // errHTTP is a generic HTTP error for any non-200 HTTP error type errHTTP struct { Code int @@ -54,8 +46,7 @@ func (e errHTTP) Error() string { } const ( - messageLimit = 1024 - visitorExpungeAfter = 30 * time.Minute + messageLimit = 1024 ) var ( @@ -147,8 +138,8 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) { func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error { v := s.visitor(r.RemoteAddr) - if !v.limiter.Allow() { - return errHTTPTooManyRequests + if err := v.RequestAllowed(); err != nil { + return err } if r.Method == http.MethodGet && r.URL.Path == "/" { return s.handleHome(w, r) @@ -157,11 +148,11 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error { } else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicRegex.MatchString(r.URL.Path) { return s.handlePublish(w, r) } else if r.Method == http.MethodGet && jsonRegex.MatchString(r.URL.Path) { - return s.handleSubscribeJSON(w, r) + return s.handleSubscribeJSON(w, r, v) } else if r.Method == http.MethodGet && sseRegex.MatchString(r.URL.Path) { - return s.handleSubscribeSSE(w, r) + return s.handleSubscribeSSE(w, r, v) } else if r.Method == http.MethodGet && rawRegex.MatchString(r.URL.Path) { - return s.handleSubscribeRaw(w, r) + return s.handleSubscribeRaw(w, r, v) } else if r.Method == http.MethodOptions { return s.handleOptions(w, r) } @@ -195,7 +186,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request) error { return nil } -func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request) error { +func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v *visitor) error { encoder := func(msg *message) (string, error) { var buf bytes.Buffer if err := json.NewEncoder(&buf).Encode(&msg); err != nil { @@ -203,10 +194,10 @@ func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request) err } return buf.String(), nil } - return s.handleSubscribe(w, r, "json", "application/stream+json", encoder) + return s.handleSubscribe(w, r, v, "json", "application/stream+json", encoder) } -func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request) error { +func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *visitor) error { encoder := func(msg *message) (string, error) { var buf bytes.Buffer if err := json.NewEncoder(&buf).Encode(&msg); err != nil { @@ -217,20 +208,24 @@ func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request) erro } return fmt.Sprintf("data: %s\n", buf.String()), nil } - return s.handleSubscribe(w, r, "sse", "text/event-stream", encoder) + return s.handleSubscribe(w, r, v, "sse", "text/event-stream", encoder) } -func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request) error { +func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *visitor) error { encoder := func(msg *message) (string, error) { if msg.Event == "" { // only handle default events return strings.ReplaceAll(msg.Message, "\n", " ") + "\n", nil } return "\n", nil // "keepalive" and "open" events just send an empty line } - return s.handleSubscribe(w, r, "raw", "text/plain", encoder) + return s.handleSubscribe(w, r, v, "raw", "text/plain", encoder) } -func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, format string, contentType string, encoder messageEncoder) error { +func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visitor, format string, contentType string, encoder messageEncoder) error { + if err := v.AddSubscription(); err != nil { + return err + } + defer v.RemoveSubscription() t := s.createTopic(strings.TrimSuffix(r.URL.Path[1:], "/"+format)) // Hack since, err := parseSince(r) if err != nil { @@ -270,6 +265,7 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, format case <-r.Context().Done(): return nil case <-time.After(s.config.KeepaliveInterval): + v.Keepalive() if err := sub(newKeepaliveMessage(t.id)); err != nil { // Send keepalive message return err } @@ -326,12 +322,12 @@ func (s *Server) updateStatsAndExpire() { // Expire visitors from rate visitors map for ip, v := range s.visitors { - if time.Since(v.seen) > visitorExpungeAfter { + if v.Stale() { delete(s.visitors, ip) } } - // Prune old messages, remove topics without subscribers + // Prune old messages, remove subscriptions without subscribers for _, t := range s.topics { t.Prune(s.config.MessageBufferDuration) subs, msgs := t.Stats() @@ -362,12 +358,8 @@ func (s *Server) visitor(remoteAddr string) *visitor { } v, exists := s.visitors[ip] if !exists { - v = &visitor{ - rate.NewLimiter(s.config.Limit, s.config.LimitBurst), - time.Now(), - } - s.visitors[ip] = v - return v + s.visitors[ip] = newVisitor(s.config) + return s.visitors[ip] } v.seen = time.Now() return v diff --git a/server/visitor.go b/server/visitor.go new file mode 100644 index 00000000..06ee32d6 --- /dev/null +++ b/server/visitor.go @@ -0,0 +1,65 @@ +package server + +import ( + "golang.org/x/time/rate" + "heckel.io/ntfy/config" + "sync" + "time" +) + +const ( + visitorExpungeAfter = 30 * time.Minute +) + +// visitor represents an API user, and its associated rate.Limiter used for rate limiting +type visitor struct { + config *config.Config + limiter *rate.Limiter + subscriptions int + seen time.Time + mu sync.Mutex +} + +func newVisitor(conf *config.Config) *visitor { + return &visitor{ + config: conf, + limiter: rate.NewLimiter(conf.RequestLimit, conf.RequestLimitBurst), + seen: time.Now(), + } +} + +func (v *visitor) RequestAllowed() error { + if !v.limiter.Allow() { + return errHTTPTooManyRequests + } + return nil +} + +func (v *visitor) AddSubscription() error { + v.mu.Lock() + defer v.mu.Unlock() + if v.subscriptions >= v.config.SubscriptionLimit { + return errHTTPTooManyRequests + } + v.subscriptions++ + return nil +} + +func (v *visitor) RemoveSubscription() { + v.mu.Lock() + defer v.mu.Unlock() + v.subscriptions-- +} + +func (v *visitor) Keepalive() { + v.mu.Lock() + defer v.mu.Unlock() + v.seen = time.Now() +} + +func (v *visitor) Stale() bool { + v.mu.Lock() + defer v.mu.Unlock() + v.seen = time.Now() + return time.Since(v.seen) > visitorExpungeAfter +}