diff --git a/cmd/serve.go b/cmd/serve.go index 7435189d..6bc9aa90 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -9,6 +9,7 @@ import ( "heckel.io/ntfy/util" "log" "math" + "net" "strings" "time" ) @@ -45,6 +46,7 @@ var flagsServe = []cli.Flag{ altsrc.NewStringFlag(&cli.StringFlag{Name: "visitor-attachment-daily-bandwidth-limit", EnvVars: []string{"NTFY_VISITOR_ATTACHMENT_DAILY_BANDWIDTH_LIMIT"}, Value: "500M", Usage: "total daily attachment download/upload bandwidth limit per visitor"}), altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-request-limit-burst", EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_BURST"}, Value: server.DefaultVisitorRequestLimitBurst, Usage: "initial limit of requests per visitor"}), altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-request-limit-replenish", EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_REPLENISH"}, Value: server.DefaultVisitorRequestLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}), + altsrc.NewStringFlag(&cli.StringFlag{Name: "visitor-request-limit-exempt-hosts", EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_EXEMPT_HOSTS"}, Value: "", Usage: "hostnames and/or IP addresses of hosts that will be exempt from the visitor request limit"}), altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-email-limit-burst", EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_BURST"}, Value: server.DefaultVisitorEmailLimitBurst, Usage: "initial limit of e-mails per visitor"}), altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-email-limit-replenish", EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_REPLENISH"}, Value: server.DefaultVisitorEmailLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}), altsrc.NewBoolFlag(&cli.BoolFlag{Name: "behind-proxy", Aliases: []string{"P"}, EnvVars: []string{"NTFY_BEHIND_PROXY"}, Value: false, Usage: "if set, use X-Forwarded-For header to determine visitor IP address (for rate limiting)"}), @@ -104,6 +106,7 @@ func execServe(c *cli.Context) error { visitorAttachmentDailyBandwidthLimitStr := c.String("visitor-attachment-daily-bandwidth-limit") visitorRequestLimitBurst := c.Int("visitor-request-limit-burst") visitorRequestLimitReplenish := c.Duration("visitor-request-limit-replenish") + visitorRequestLimitExemptHosts := util.SplitNoEmpty(c.String("visitor-request-limit-exempt-hosts"), ",") visitorEmailLimitBurst := c.Int("visitor-email-limit-burst") visitorEmailLimitReplenish := c.Duration("visitor-email-limit-replenish") behindProxy := c.Bool("behind-proxy") @@ -164,6 +167,19 @@ func execServe(c *cli.Context) error { return fmt.Errorf("config option visitor-attachment-daily-bandwidth-limit must be lower than %d", math.MaxInt) } + // Resolve hosts + visitorRequestLimitExemptIPs := make([]string, 0) + for _, host := range visitorRequestLimitExemptHosts { + ips, err := net.LookupIP(host) + if err != nil { + log.Printf("cannot resolve host %s: %s, ignoring visitor request exemption", host, err.Error()) + continue + } + for _, ip := range ips { + visitorRequestLimitExemptIPs = append(visitorRequestLimitExemptIPs, ip.String()) + } + } + // Run server conf := server.NewConfig() conf.BaseURL = baseURL @@ -197,6 +213,7 @@ func execServe(c *cli.Context) error { conf.VisitorAttachmentDailyBandwidthLimit = int(visitorAttachmentDailyBandwidthLimit) conf.VisitorRequestLimitBurst = visitorRequestLimitBurst conf.VisitorRequestLimitReplenish = visitorRequestLimitReplenish + conf.VisitorRequestExemptIPAddrs = visitorRequestLimitExemptIPs conf.VisitorEmailLimitBurst = visitorEmailLimitBurst conf.VisitorEmailLimitReplenish = visitorEmailLimitReplenish conf.BehindProxy = behindProxy diff --git a/server/config.go b/server/config.go index 1d72dfb8..4ffea603 100644 --- a/server/config.go +++ b/server/config.go @@ -83,6 +83,7 @@ type Config struct { VisitorAttachmentDailyBandwidthLimit int VisitorRequestLimitBurst int VisitorRequestLimitReplenish time.Duration + VisitorRequestExemptIPAddrs []string VisitorEmailLimitBurst int VisitorEmailLimitReplenish time.Duration BehindProxy bool @@ -120,6 +121,7 @@ func NewConfig() *Config { VisitorAttachmentDailyBandwidthLimit: DefaultVisitorAttachmentDailyBandwidthLimit, VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst, VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish, + VisitorRequestExemptIPAddrs: make([]string, 0), VisitorEmailLimitBurst: DefaultVisitorEmailLimitBurst, VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish, BehindProxy: false, diff --git a/server/server.go b/server/server.go index 6854dd2e..5dd4bfce 100644 --- a/server/server.go +++ b/server/server.go @@ -251,16 +251,17 @@ func (s *Server) Stop() { } func (s *Server) handle(w http.ResponseWriter, r *http.Request) { - if err := s.handleInternal(w, r); err != nil { + v := s.visitor(r) + if err := s.handleInternal(w, r, v); err != nil { if websocket.IsWebSocketUpgrade(r) { - log.Printf("[%s] WS %s %s - %s", r.RemoteAddr, r.Method, r.URL.Path, err.Error()) + log.Printf("[%s] WS %s %s - %s", v.ip, r.Method, r.URL.Path, err.Error()) return // Do not attempt to write to upgraded connection } httpErr, ok := err.(*errHTTP) if !ok { httpErr = errHTTPInternalError } - log.Printf("[%s] HTTP %s %s - %d - %d - %s", r.RemoteAddr, r.Method, r.URL.Path, httpErr.HTTPCode, httpErr.Code, err.Error()) + log.Printf("[%s] HTTP %s %s - %d - %d - %s", v.ip, r.Method, r.URL.Path, httpErr.HTTPCode, httpErr.Code, err.Error()) w.Header().Set("Content-Type", "application/json") w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests w.WriteHeader(httpErr.HTTPCode) @@ -268,8 +269,7 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) { } } -func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error { - v := s.visitor(r) +func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visitor) error { if r.Method == http.MethodGet && r.URL.Path == "/" { return s.handleHome(w, r) } else if r.Method == http.MethodGet && r.URL.Path == "/example.html" { @@ -404,14 +404,14 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito if s.firebase != nil && firebase && !delayed { go func() { if err := s.firebase(m); err != nil { - log.Printf("Unable to publish to Firebase: %v", err.Error()) + log.Printf("[%s] FB - Unable to publish to Firebase: %v", v.ip, err.Error()) } }() } if s.mailer != nil && email != "" && !delayed { go func() { if err := s.mailer.Send(v.ip, email, m); err != nil { - log.Printf("Unable to send email: %v", err.Error()) + log.Printf("[%s] MAIL - Unable to send email: %v", v.ip, err.Error()) } }() } @@ -1063,7 +1063,9 @@ func (s *Server) sendDelayedMessages() error { func (s *Server) limitRequests(next handleFunc) handleFunc { return func(w http.ResponseWriter, r *http.Request, v *visitor) error { - if err := v.RequestAllowed(); err != nil { + if util.InStringList(s.config.VisitorRequestExemptIPAddrs, v.ip) { + return next(w, r, v) + } else if err := v.RequestAllowed(); err != nil { return errHTTPTooManyRequestsLimitRequests } return next(w, r, v)