From 7280ae1ebcc6005226e3ff2479c14842a578fb64 Mon Sep 17 00:00:00 2001 From: Philipp Heckel Date: Fri, 24 Dec 2021 00:03:04 +0100 Subject: [PATCH] Email rate limiting + tests --- cmd/serve.go | 28 +++++++++++++---- server/config.go | 7 +++++ server/mailer.go | 35 ++++++++++++++++++++++ server/server.go | 50 +++++++++++++++---------------- server/server.yml | 7 +++++ server/server_test.go | 70 +++++++++++++++++++++++++++++++++++++++++++ server/visitor.go | 20 ++++++++++--- 7 files changed, 183 insertions(+), 34 deletions(-) create mode 100644 server/mailer.go diff --git a/cmd/serve.go b/cmd/serve.go index 18b7e20e..4f6e0809 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -1,4 +1,3 @@ -// Package cmd provides the ntfy CLI application package cmd import ( @@ -22,10 +21,16 @@ var flagsServe = []cli.Flag{ altsrc.NewDurationFlag(&cli.DurationFlag{Name: "cache-duration", Aliases: []string{"b"}, EnvVars: []string{"NTFY_CACHE_DURATION"}, Value: server.DefaultCacheDuration, Usage: "buffer messages for this time to allow `since` requests"}), altsrc.NewDurationFlag(&cli.DurationFlag{Name: "keepalive-interval", Aliases: []string{"k"}, EnvVars: []string{"NTFY_KEEPALIVE_INTERVAL"}, Value: server.DefaultKeepaliveInterval, Usage: "interval of keepalive messages"}), altsrc.NewDurationFlag(&cli.DurationFlag{Name: "manager-interval", Aliases: []string{"m"}, EnvVars: []string{"NTFY_MANAGER_INTERVAL"}, Value: server.DefaultManagerInterval, Usage: "interval of for message pruning and stats printing"}), + altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-addr", EnvVars: []string{"NTFY_SMTP_ADDR"}, Usage: "SMTP address (host:port) to allow email sending"}), + altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-user", EnvVars: []string{"NTFY_SMTP_USER"}, Usage: "SMTP user (if e-mail sending is enabled)"}), + altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-pass", EnvVars: []string{"NTFY_SMTP_PASS"}, Usage: "SMTP password (if e-mail sending is enabled)"}), + altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-from", EnvVars: []string{"NTFY_SMTP_FROM"}, Usage: "SMTP sender address (if e-mail sending is enabled)"}), altsrc.NewIntFlag(&cli.IntFlag{Name: "global-topic-limit", Aliases: []string{"T"}, EnvVars: []string{"NTFY_GLOBAL_TOPIC_LIMIT"}, Value: server.DefaultGlobalTopicLimit, Usage: "total number of topics allowed"}), - altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-subscription-limit", Aliases: []string{"V"}, EnvVars: []string{"NTFY_VISITOR_SUBSCRIPTION_LIMIT"}, Value: server.DefaultVisitorSubscriptionLimit, Usage: "number of subscriptions per visitor"}), - altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-request-limit-burst", Aliases: []string{"B"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_BURST"}, Value: server.DefaultVisitorRequestLimitBurst, Usage: "initial limit of requests per visitor"}), - altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-request-limit-replenish", Aliases: []string{"R"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_REPLENISH"}, Value: server.DefaultVisitorRequestLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}), + altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-subscription-limit", EnvVars: []string{"NTFY_VISITOR_SUBSCRIPTION_LIMIT"}, Value: server.DefaultVisitorSubscriptionLimit, Usage: "number of subscriptions per visitor"}), + altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-request-limit-burst", EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_BURST"}, Value: server.DefaultVisitorRequestLimitBurst, Usage: "initial limit of requests per visitor"}), + altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-request-limit-replenish", EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_REPLENISH"}, Value: server.DefaultVisitorRequestLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}), + altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-email-limit-burst", EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_BURST"}, Value: server.DefaultVisitorEmailLimitBurst, Usage: "initial limit of e-mails per visitor"}), + altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-email-limit-replenish", EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_REPLENISH"}, Value: server.DefaultVisitorEmailLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}), altsrc.NewBoolFlag(&cli.BoolFlag{Name: "behind-proxy", Aliases: []string{"P"}, EnvVars: []string{"NTFY_BEHIND_PROXY"}, Value: false, Usage: "if set, use X-Forwarded-For header to determine visitor IP address (for rate limiting)"}), } @@ -61,10 +66,16 @@ func execServe(c *cli.Context) error { cacheDuration := c.Duration("cache-duration") keepaliveInterval := c.Duration("keepalive-interval") managerInterval := c.Duration("manager-interval") + smtpAddr := c.String("smtp-addr") + smtpUser := c.String("smtp-user") + smtpPass := c.String("smtp-pass") + smtpFrom := c.String("smtp-from") globalTopicLimit := c.Int("global-topic-limit") visitorSubscriptionLimit := c.Int("visitor-subscription-limit") visitorRequestLimitBurst := c.Int("visitor-request-limit-burst") visitorRequestLimitReplenish := c.Duration("visitor-request-limit-replenish") + visitorEmailLimitBurst := c.Int("visitor-email-limit-burst") + visitorEmailLimitReplenish := c.Duration("visitor-email-limit-replenish") behindProxy := c.Bool("behind-proxy") // Check values @@ -82,6 +93,8 @@ func execServe(c *cli.Context) error { return errors.New("if set, certificate file must exist") } else if listenHTTPS != "" && (keyFile == "" || certFile == "") { return errors.New("if listen-https is set, both key-file and cert-file must be set") + } else if smtpAddr != "" && (smtpUser == "" || smtpPass == "" || smtpFrom == "") { + return errors.New("if smtp-addr is set, smtp-user, smtp-pass and smtp-from must also be set") } // Run server @@ -95,11 +108,16 @@ func execServe(c *cli.Context) error { conf.CacheDuration = cacheDuration conf.KeepaliveInterval = keepaliveInterval conf.ManagerInterval = managerInterval - //XXXXXXXXX + conf.SMTPAddr = smtpAddr + conf.SMTPUser = smtpUser + conf.SMTPPass = smtpPass + conf.SMTPFrom = smtpFrom conf.GlobalTopicLimit = globalTopicLimit conf.VisitorSubscriptionLimit = visitorSubscriptionLimit conf.VisitorRequestLimitBurst = visitorRequestLimitBurst conf.VisitorRequestLimitReplenish = visitorRequestLimitReplenish + conf.VisitorEmailLimitBurst = visitorEmailLimitBurst + conf.VisitorEmailLimitReplenish = visitorEmailLimitReplenish conf.BehindProxy = behindProxy s, err := server.New(conf) if err != nil { diff --git a/server/config.go b/server/config.go index 58809df6..82214788 100644 --- a/server/config.go +++ b/server/config.go @@ -20,11 +20,14 @@ const ( // Defines all the limits // - global topic limit: max number of topics overall // - per visitor request limit: max number of PUT/GET/.. requests (here: 60 requests bucket, replenished at a rate of one per 10 seconds) +// - per visitor email limit: max number of emails (here: 16 email bucket, replenished at a rate of one per hour) // - per visitor subscription limit: max number of subscriptions (active HTTP connections) per per-visitor/IP const ( DefaultGlobalTopicLimit = 5000 DefaultVisitorRequestLimitBurst = 60 DefaultVisitorRequestLimitReplenish = 10 * time.Second + DefaultVisitorEmailLimitBurst = 16 + DefaultVisitorEmailLimitReplenish = time.Hour DefaultVisitorSubscriptionLimit = 30 ) @@ -51,6 +54,8 @@ type Config struct { GlobalTopicLimit int VisitorRequestLimitBurst int VisitorRequestLimitReplenish time.Duration + VisitorEmailLimitBurst int + VisitorEmailLimitReplenish time.Duration VisitorSubscriptionLimit int BehindProxy bool } @@ -75,6 +80,8 @@ func NewConfig() *Config { GlobalTopicLimit: DefaultGlobalTopicLimit, VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst, VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish, + VisitorEmailLimitBurst: DefaultVisitorEmailLimitBurst, + VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish, VisitorSubscriptionLimit: DefaultVisitorSubscriptionLimit, BehindProxy: false, } diff --git a/server/mailer.go b/server/mailer.go new file mode 100644 index 00000000..1d3af232 --- /dev/null +++ b/server/mailer.go @@ -0,0 +1,35 @@ +package server + +import ( + "fmt" + "net" + "net/smtp" + "strings" +) + +type mailer interface { + Send(to string, m *message) error +} + +type smtpMailer struct { + config *Config +} + +func (s *smtpMailer) Send(to string, m *message) error { + host, _, err := net.SplitHostPort(s.config.SMTPAddr) + if err != nil { + return err + } + subject := m.Title + if subject == "" { + subject = m.Message + } + subject += " - " + m.Topic + subject = strings.ReplaceAll(strings.ReplaceAll(subject, "\r", ""), "\n", " ") + msg := []byte(fmt.Sprintf("From: %s\r\n"+ + "To: %s\r\n"+ + "Subject: %s\r\n\r\n"+ + "%s\r\n", s.config.SMTPFrom, to, subject, m.Message)) + auth := smtp.PlainAuth("", s.config.SMTPUser, s.config.SMTPPass, host) + return smtp.SendMail(s.config.SMTPAddr, auth, s.config.SMTPFrom, []string{to}, msg) +} diff --git a/server/server.go b/server/server.go index 655c686b..ee95e5e3 100644 --- a/server/server.go +++ b/server/server.go @@ -3,7 +3,7 @@ package server import ( "bytes" "context" - "embed" // required for go:embed + "embed" "encoding/json" firebase "firebase.google.com/go" "firebase.google.com/go/messaging" @@ -15,7 +15,6 @@ import ( "log" "net" "net/http" - "net/smtp" "regexp" "strconv" "strings" @@ -34,6 +33,7 @@ type Server struct { topics map[string]*topic visitors map[string]*visitor firebase subscriber + mailer mailer messages int64 cache cache closeChan chan bool @@ -111,6 +111,7 @@ var ( const ( firebaseControlTopic = "~control" // See Android if changed + emptyMessageBody = "triggered" ) // New instantiates a new Server. It creates the cache and adds a Firebase @@ -124,6 +125,10 @@ func New(conf *Config) (*Server, error) { return nil, err } } + var mailer mailer + if conf.SMTPAddr != "" { + mailer = &smtpMailer{config: conf} + } cache, err := createCache(conf) if err != nil { return nil, err @@ -136,6 +141,7 @@ func New(conf *Config) (*Server, error) { config: conf, cache: cache, firebase: firebaseSubscriber, + mailer: mailer, topics: topics, visitors: make(map[string]*visitor), }, nil @@ -189,23 +195,6 @@ func createFirebaseSubscriber(conf *Config) (subscriber, error) { }, nil } -func (s *Server) sendMail(to string, m *message) error { - host, _, err := net.SplitHostPort(s.config.SMTPAddr) - if err != nil { - return err - } - subject := m.Title - if subject == "" { - subject = m.Message - } - msg := []byte(fmt.Sprintf("From: %s\r\n"+ - "To: %s\r\n"+ - "Subject: %s\r\n\r\n"+ - "%s\r\n", s.config.SMTPFrom, to, subject, m.Message)) - auth := smtp.PlainAuth("", s.config.SMTPUser, s.config.SMTPPass, host) - return smtp.SendMail(s.config.SMTPAddr, auth, s.config.SMTPFrom, []string{to}, msg) -} - // Run executes the main server. It listens on HTTP (+ HTTPS, if configured), and starts // a manager go routine to print stats and prune messages. func (s *Server) Run() error { @@ -314,7 +303,7 @@ func (s *Server) handleDocs(w http.ResponseWriter, r *http.Request) error { return nil } -func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, _ *visitor) error { +func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visitor) error { t, err := s.topicFromPath(r.URL.Path) if err != nil { return err @@ -329,8 +318,16 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, _ *visito if err != nil { return err } + if email != "" { + if err := v.EmailAllowed(); err != nil { + return err + } + } + if s.mailer == nil && email != "" { + return errHTTPBadRequest + } if m.Message == "" { - m.Message = "triggered" + m.Message = emptyMessageBody } delayed := m.Time > time.Now().Unix() if !delayed { @@ -345,9 +342,9 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, _ *visito } }() } - if s.config.SMTPAddr != "" && email != "" && !delayed { + if s.mailer != nil && email != "" && !delayed { go func() { - if err := s.sendMail(email, m); err != nil { + if err := s.mailer.Send(email, m); err != nil { log.Printf("Unable to send email: %v", err.Error()) } }() @@ -369,7 +366,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, _ *visito func (s *Server) parseParams(r *http.Request, m *message) (cache bool, firebase bool, email string, err error) { cache = readParam(r, "x-cache", "cache") != "no" firebase = readParam(r, "x-firebase", "firebase") != "no" - email = readParam(r, "x-email", "email", "mail", "e") + email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e") m.Title = readParam(r, "x-title", "title", "t") messageStr := readParam(r, "x-message", "message", "m") if messageStr != "" { @@ -391,6 +388,9 @@ func (s *Server) parseParams(r *http.Request, m *message) (cache bool, firebase if !cache { return false, false, "", errHTTPBadRequest } + if email != "" { + return false, false, "", errHTTPBadRequest // we cannot store the email address (yet) + } delay, err := util.ParseFutureTime(delayStr, time.Now()) if err != nil { return false, false, "", errHTTPBadRequest @@ -740,7 +740,7 @@ func (s *Server) sendDelayedMessages() error { log.Printf("unable to publish to Firebase: %v", err.Error()) } } - // FIXME delayed email + // TODO delayed email sending } if err := s.cache.MarkPublished(m); err != nil { return err diff --git a/server/server.yml b/server/server.yml index 1ecc5de6..8bf686de 100644 --- a/server/server.yml +++ b/server/server.yml @@ -61,6 +61,13 @@ # visitor-request-limit-burst: 60 # visitor-request-limit-replenish: 10s +# Rate limiting: Allowed emails per visitor: +# - visitor-email-limit-burst is the initial bucket of emails each visitor has +# - visitor-email-limit-replenish is the rate at which the bucket is refilled +# +# visitor-email-limit-burst: 16 +# visitor-email-limit-replenish: 1h + # If set, the X-Forwarded-For header is used to determine the visitor IP address # instead of the remote address of the connection. # diff --git a/server/server_test.go b/server/server_test.go index 377f1d64..3a9b20b9 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -508,6 +508,76 @@ func TestServer_Curl_Publish_Poll(t *testing.T) { } */ +type testMailer struct { + count int +} + +func (t *testMailer) Send(to string, m *message) error { + t.count++ + return nil +} + +func TestServer_PublishTooManyEmails_Defaults(t *testing.T) { + s := newTestServer(t, newTestConfig(t)) + s.mailer = &testMailer{} + for i := 0; i < 16; i++ { + response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), map[string]string{ + "E-Mail": "test@example.com", + }) + require.Equal(t, 200, response.Code) + } + response := request(t, s, "PUT", "/mytopic", "one too many", map[string]string{ + "E-Mail": "test@example.com", + }) + require.Equal(t, 429, response.Code) +} + +func TestServer_PublishTooManyEmails_Replenish(t *testing.T) { + c := newTestConfig(t) + c.VisitorEmailLimitReplenish = 500 * time.Millisecond + s := newTestServer(t, c) + s.mailer = &testMailer{} + for i := 0; i < 16; i++ { + response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), map[string]string{ + "E-Mail": "test@example.com", + }) + require.Equal(t, 200, response.Code) + } + response := request(t, s, "PUT", "/mytopic", "one too many", map[string]string{ + "E-Mail": "test@example.com", + }) + require.Equal(t, 429, response.Code) + + time.Sleep(510 * time.Millisecond) + response = request(t, s, "PUT", "/mytopic", "this should be okay again too many", map[string]string{ + "E-Mail": "test@example.com", + }) + require.Equal(t, 200, response.Code) + + response = request(t, s, "PUT", "/mytopic", "and bad again", map[string]string{ + "E-Mail": "test@example.com", + }) + require.Equal(t, 429, response.Code) +} + +func TestServer_PublishDelayedEmail_Fail(t *testing.T) { + s := newTestServer(t, newTestConfig(t)) + s.mailer = &testMailer{} + response := request(t, s, "PUT", "/mytopic", "fail", map[string]string{ + "E-Mail": "test@example.com", + "Delay": "20 min", + }) + require.Equal(t, 400, response.Code) +} + +func TestServer_PublishEmailNoMailer_Fail(t *testing.T) { + s := newTestServer(t, newTestConfig(t)) + response := request(t, s, "PUT", "/mytopic", "fail", map[string]string{ + "E-Mail": "test@example.com", + }) + require.Equal(t, 400, response.Code) +} + func newTestConfig(t *testing.T) *Config { conf := NewConfig() conf.CacheFile = filepath.Join(t.TempDir(), "cache.db") diff --git a/server/visitor.go b/server/visitor.go index 9d99e94a..269b3162 100644 --- a/server/visitor.go +++ b/server/visitor.go @@ -8,13 +8,17 @@ import ( ) const ( - visitorExpungeAfter = 30 * time.Minute + // visitorExpungeAfter defines how long a visitor is active before it is removed from memory. This number + // has to be very high to prevent e-mail abuse, but it doesn't really affect the other limits anyway, since + // they are replenished faster (typically). + visitorExpungeAfter = 24 * time.Hour ) // visitor represents an API user, and its associated rate.Limiter used for rate limiting type visitor struct { config *Config - limiter *rate.Limiter + requests *rate.Limiter + emails *rate.Limiter subscriptions *util.Limiter seen time.Time mu sync.Mutex @@ -23,14 +27,22 @@ type visitor struct { func newVisitor(conf *Config) *visitor { return &visitor{ config: conf, - limiter: rate.NewLimiter(rate.Every(conf.VisitorRequestLimitReplenish), conf.VisitorRequestLimitBurst), + requests: rate.NewLimiter(rate.Every(conf.VisitorRequestLimitReplenish), conf.VisitorRequestLimitBurst), + emails: rate.NewLimiter(rate.Every(conf.VisitorEmailLimitReplenish), conf.VisitorEmailLimitBurst), subscriptions: util.NewLimiter(int64(conf.VisitorSubscriptionLimit)), seen: time.Now(), } } func (v *visitor) RequestAllowed() error { - if !v.limiter.Allow() { + if !v.requests.Allow() { + return errHTTPTooManyRequests + } + return nil +} + +func (v *visitor) EmailAllowed() error { + if !v.emails.Allow() { return errHTTPTooManyRequests } return nil