mirror of
https://github.com/binwiederhier/ntfy.git
synced 2025-01-09 02:16:05 +01:00
Fix rate limiting behind proxy, make configurable
This commit is contained in:
parent
86a16e3944
commit
0170f673bd
5 changed files with 99 additions and 45 deletions
19
cmd/app.go
19
cmd/app.go
|
@ -22,8 +22,13 @@ func New() *cli.App {
|
|||
altsrc.NewStringFlag(&cli.StringFlag{Name: "firebase-key-file", Aliases: []string{"F"}, EnvVars: []string{"NTFY_FIREBASE_KEY_FILE"}, Usage: "Firebase credentials file; if set additionally publish to FCM topic"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-file", Aliases: []string{"C"}, EnvVars: []string{"NTFY_CACHE_FILE"}, Usage: "cache file used for message caching"}),
|
||||
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "cache-duration", Aliases: []string{"b"}, EnvVars: []string{"NTFY_CACHE_DURATION"}, Value: config.DefaultCacheDuration, Usage: "buffer messages for this time to allow `since` requests"}),
|
||||
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "keepalive-interval", Aliases: []string{"k"}, EnvVars: []string{"NTFY_KEEPALIVE_INTERVAL"}, Value: config.DefaultKeepaliveInterval, Usage: "default interval of keepalive messages"}),
|
||||
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "manager-interval", Aliases: []string{"m"}, EnvVars: []string{"NTFY_MANAGER_INTERVAL"}, Value: config.DefaultManagerInterval, Usage: "default interval of for message pruning and stats printing"}),
|
||||
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "keepalive-interval", Aliases: []string{"k"}, EnvVars: []string{"NTFY_KEEPALIVE_INTERVAL"}, Value: config.DefaultKeepaliveInterval, Usage: "interval of keepalive messages"}),
|
||||
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "manager-interval", Aliases: []string{"m"}, EnvVars: []string{"NTFY_MANAGER_INTERVAL"}, Value: config.DefaultManagerInterval, Usage: "interval of for message pruning and stats printing"}),
|
||||
altsrc.NewIntFlag(&cli.IntFlag{Name: "global-topic-limit", Aliases: []string{"T"}, EnvVars: []string{"NTFY_GLOBAL_TOPIC_LIMIT"}, Value: config.DefaultGlobalTopicLimit, Usage: "total number of topics allowed"}),
|
||||
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-subscription-limit", Aliases: []string{"V"}, EnvVars: []string{"NTFY_VISITOR_SUBSCRIPTION_LIMIT"}, Value: config.DefaultVisitorSubscriptionLimit, Usage: "number of subscriptions per visitor"}),
|
||||
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-request-limit-burst", Aliases: []string{"B"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_BURST"}, Value: config.DefaultVisitorRequestLimitBurst, Usage: "initial limit of requests per visitor"}),
|
||||
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-request-limit-replenish", Aliases: []string{"R"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_REPLENISH"}, Value: config.DefaultVisitorRequestLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}),
|
||||
altsrc.NewBoolFlag(&cli.BoolFlag{Name: "behind-proxy", Aliases: []string{"P"}, EnvVars: []string{"NTFY_BEHIND_PROXY"}, Value: false, Usage: "if set, use X-Forwarded-For header to determine visitor IP address (for rate limiting)"}),
|
||||
}
|
||||
return &cli.App{
|
||||
Name: "ntfy",
|
||||
|
@ -50,6 +55,11 @@ func execRun(c *cli.Context) error {
|
|||
cacheDuration := c.Duration("cache-duration")
|
||||
keepaliveInterval := c.Duration("keepalive-interval")
|
||||
managerInterval := c.Duration("manager-interval")
|
||||
globalTopicLimit := c.Int("global-topic-limit")
|
||||
visitorSubscriptionLimit := c.Int("visitor-subscription-limit")
|
||||
visitorRequestLimitBurst := c.Int("visitor-request-limit-burst")
|
||||
visitorRequestLimitReplenish := c.Duration("visitor-request-limit-replenish")
|
||||
behindProxy := c.Bool("behind-proxy")
|
||||
|
||||
// Check values
|
||||
if firebaseKeyFile != "" && !util.FileExists(firebaseKeyFile) {
|
||||
|
@ -69,6 +79,11 @@ func execRun(c *cli.Context) error {
|
|||
conf.CacheDuration = cacheDuration
|
||||
conf.KeepaliveInterval = keepaliveInterval
|
||||
conf.ManagerInterval = managerInterval
|
||||
conf.GlobalTopicLimit = globalTopicLimit
|
||||
conf.VisitorSubscriptionLimit = visitorSubscriptionLimit
|
||||
conf.VisitorRequestLimitBurst = visitorRequestLimitBurst
|
||||
conf.VisitorRequestLimitReplenish = visitorRequestLimitReplenish
|
||||
conf.BehindProxy = behindProxy
|
||||
s, err := server.New(conf)
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"golang.org/x/time/rate"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -15,42 +14,44 @@ const (
|
|||
)
|
||||
|
||||
// Defines all the limits
|
||||
// - request limit: max number of PUT/GET/.. requests (here: 50 requests bucket, replenished at a rate of one per 10 seconds)
|
||||
// - global topic limit: max number of topics overall
|
||||
// - subscription limit: max number of subscriptions (active HTTP connections) per per-visitor/IP
|
||||
var (
|
||||
defaultGlobalTopicLimit = 5000
|
||||
defaultVisitorRequestLimit = rate.Every(10 * time.Second)
|
||||
defaultVisitorRequestLimitBurst = 60
|
||||
defaultVisitorSubscriptionLimit = 30
|
||||
// - per visistor request limit: max number of PUT/GET/.. requests (here: 60 requests bucket, replenished at a rate of one per 10 seconds)
|
||||
// - per visistor subscription limit: max number of subscriptions (active HTTP connections) per per-visitor/IP
|
||||
const (
|
||||
DefaultGlobalTopicLimit = 5000
|
||||
DefaultVisitorRequestLimitBurst = 60
|
||||
DefaultVisitorRequestLimitReplenish = 10 * time.Second
|
||||
DefaultVisitorSubscriptionLimit = 30
|
||||
)
|
||||
|
||||
// Config is the main config struct for the application. Use New to instantiate a default config struct.
|
||||
type Config struct {
|
||||
ListenHTTP string
|
||||
FirebaseKeyFile string
|
||||
CacheFile string
|
||||
CacheDuration time.Duration
|
||||
KeepaliveInterval time.Duration
|
||||
ManagerInterval time.Duration
|
||||
GlobalTopicLimit int
|
||||
VisitorRequestLimit rate.Limit
|
||||
VisitorRequestLimitBurst int
|
||||
VisitorSubscriptionLimit int
|
||||
ListenHTTP string
|
||||
FirebaseKeyFile string
|
||||
CacheFile string
|
||||
CacheDuration time.Duration
|
||||
KeepaliveInterval time.Duration
|
||||
ManagerInterval time.Duration
|
||||
GlobalTopicLimit int
|
||||
VisitorRequestLimitBurst int
|
||||
VisitorRequestLimitReplenish time.Duration
|
||||
VisitorSubscriptionLimit int
|
||||
BehindProxy bool
|
||||
}
|
||||
|
||||
// New instantiates a default new config
|
||||
func New(listenHTTP string) *Config {
|
||||
return &Config{
|
||||
ListenHTTP: listenHTTP,
|
||||
FirebaseKeyFile: "",
|
||||
CacheFile: "",
|
||||
CacheDuration: DefaultCacheDuration,
|
||||
KeepaliveInterval: DefaultKeepaliveInterval,
|
||||
ManagerInterval: DefaultManagerInterval,
|
||||
GlobalTopicLimit: defaultGlobalTopicLimit,
|
||||
VisitorRequestLimit: defaultVisitorRequestLimit,
|
||||
VisitorRequestLimitBurst: defaultVisitorRequestLimitBurst,
|
||||
VisitorSubscriptionLimit: defaultVisitorSubscriptionLimit,
|
||||
ListenHTTP: listenHTTP,
|
||||
FirebaseKeyFile: "",
|
||||
CacheFile: "",
|
||||
CacheDuration: DefaultCacheDuration,
|
||||
KeepaliveInterval: DefaultKeepaliveInterval,
|
||||
ManagerInterval: DefaultManagerInterval,
|
||||
GlobalTopicLimit: DefaultGlobalTopicLimit,
|
||||
VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst,
|
||||
VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish,
|
||||
VisitorSubscriptionLimit: DefaultVisitorSubscriptionLimit,
|
||||
BehindProxy: false,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,6 +25,30 @@
|
|||
#
|
||||
# keepalive-interval: 30s
|
||||
|
||||
# Interval in which the manager prunes old messages, deletes topics and prints the stats.
|
||||
# Interval in which the manager prunes old messages, deletes topics
|
||||
# and prints the stats.
|
||||
#
|
||||
# manager-interval: 1m
|
||||
|
||||
# Rate limiting: Total number of topics before the server rejects new topics.
|
||||
#
|
||||
# global-topic-limit: 5000
|
||||
|
||||
# Rate limiting: Number of subscriptions per visitor (IP address)
|
||||
#
|
||||
# visitor-subscription-limit: 30
|
||||
|
||||
# Rate limiting: Allowed GET/PUT/POST requests per second, per visitor:
|
||||
# - visitor-request-limit-burst is the initial bucket of requests each visitor has
|
||||
# - visitor-request-limit-replenish is the rate at which the bucket is refilled
|
||||
#
|
||||
# visitor-request-limit-burst: 60
|
||||
# visitor-request-limit-replenish: 10s
|
||||
|
||||
# If set, the X-Forwarded-For header is used to determine the visitor IP address
|
||||
# instead of the remote address of the connection.
|
||||
#
|
||||
# WARNING: If you are behind a proxy, you must set this, otherwise all visitors are rate limited
|
||||
# as if they are one.
|
||||
#
|
||||
# behind-proxy: false
|
||||
|
|
|
@ -159,24 +159,22 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
|
||||
v := s.visitor(r.RemoteAddr)
|
||||
if err := v.RequestAllowed(); err != nil {
|
||||
return err
|
||||
}
|
||||
if r.Method == http.MethodGet && r.URL.Path == "/" {
|
||||
return s.handleHome(w, r)
|
||||
} else if r.Method == http.MethodHead && r.URL.Path == "/" {
|
||||
return s.handleEmpty(w, r)
|
||||
} else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {
|
||||
return s.handleStatic(w, r)
|
||||
} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicRegex.MatchString(r.URL.Path) {
|
||||
return s.handlePublish(w, r, v)
|
||||
} else if r.Method == http.MethodGet && jsonRegex.MatchString(r.URL.Path) {
|
||||
return s.handleSubscribeJSON(w, r, v)
|
||||
} else if r.Method == http.MethodGet && sseRegex.MatchString(r.URL.Path) {
|
||||
return s.handleSubscribeSSE(w, r, v)
|
||||
} else if r.Method == http.MethodGet && rawRegex.MatchString(r.URL.Path) {
|
||||
return s.handleSubscribeRaw(w, r, v)
|
||||
} else if r.Method == http.MethodOptions {
|
||||
return s.handleOptions(w, r)
|
||||
} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicRegex.MatchString(r.URL.Path) {
|
||||
return s.withRateLimit(w, r, s.handlePublish)
|
||||
} else if r.Method == http.MethodGet && jsonRegex.MatchString(r.URL.Path) {
|
||||
return s.withRateLimit(w, r, s.handleSubscribeJSON)
|
||||
} else if r.Method == http.MethodGet && sseRegex.MatchString(r.URL.Path) {
|
||||
return s.withRateLimit(w, r, s.handleSubscribeSSE)
|
||||
} else if r.Method == http.MethodGet && rawRegex.MatchString(r.URL.Path) {
|
||||
return s.withRateLimit(w, r, s.handleSubscribeRaw)
|
||||
}
|
||||
return errHTTPNotFound
|
||||
}
|
||||
|
@ -186,6 +184,10 @@ func (s *Server) handleHome(w http.ResponseWriter, r *http.Request) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *Server) handleEmpty(w http.ResponseWriter, r *http.Request) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request) error {
|
||||
http.FileServer(http.FS(webStaticFs)).ServeHTTP(w, r)
|
||||
return nil
|
||||
|
@ -394,15 +396,27 @@ func (s *Server) updateStatsAndExpire() {
|
|||
s.messages, len(s.topics), subscribers, messages, len(s.visitors))
|
||||
}
|
||||
|
||||
func (s *Server) withRateLimit(w http.ResponseWriter, r *http.Request, handler func(w http.ResponseWriter, r *http.Request, v *visitor) error) error {
|
||||
v := s.visitor(r)
|
||||
if err := v.RequestAllowed(); err != nil {
|
||||
return err
|
||||
}
|
||||
return handler(w, r, v)
|
||||
}
|
||||
|
||||
// visitor creates or retrieves a rate.Limiter for the given visitor.
|
||||
// This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT).
|
||||
func (s *Server) visitor(remoteAddr string) *visitor {
|
||||
func (s *Server) visitor(r *http.Request) *visitor {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
remoteAddr := r.RemoteAddr
|
||||
ip, _, err := net.SplitHostPort(remoteAddr)
|
||||
if err != nil {
|
||||
ip = remoteAddr // This should not happen in real life; only in tests.
|
||||
}
|
||||
if s.config.BehindProxy && r.Header.Get("X-Forwarded-For") != "" {
|
||||
ip = r.Header.Get("X-Forwarded-For")
|
||||
}
|
||||
v, exists := s.visitors[ip]
|
||||
if !exists {
|
||||
s.visitors[ip] = newVisitor(s.config)
|
||||
|
|
|
@ -24,7 +24,7 @@ type visitor struct {
|
|||
func newVisitor(conf *config.Config) *visitor {
|
||||
return &visitor{
|
||||
config: conf,
|
||||
limiter: rate.NewLimiter(conf.VisitorRequestLimit, conf.VisitorRequestLimitBurst),
|
||||
limiter: rate.NewLimiter(rate.Every(conf.VisitorRequestLimitReplenish), conf.VisitorRequestLimitBurst),
|
||||
subscriptions: util.NewLimiter(int64(conf.VisitorSubscriptionLimit)),
|
||||
seen: time.Now(),
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue