mirror of
				https://github.com/binwiederhier/ntfy.git
				synced 2025-10-31 13:02:24 +01:00 
			
		
		
		
	Email rate limiting + tests
This commit is contained in:
		
							parent
							
								
									873c57b3d8
								
							
						
					
					
						commit
						7280ae1ebc
					
				
					 7 changed files with 183 additions and 34 deletions
				
			
		
							
								
								
									
										28
									
								
								cmd/serve.go
									
										
									
									
									
								
							
							
						
						
									
										28
									
								
								cmd/serve.go
									
										
									
									
									
								
							|  | @ -1,4 +1,3 @@ | |||
| // Package cmd provides the ntfy CLI application | ||||
| package cmd | ||||
| 
 | ||||
| import ( | ||||
|  | @ -22,10 +21,16 @@ var flagsServe = []cli.Flag{ | |||
| 	altsrc.NewDurationFlag(&cli.DurationFlag{Name: "cache-duration", Aliases: []string{"b"}, EnvVars: []string{"NTFY_CACHE_DURATION"}, Value: server.DefaultCacheDuration, Usage: "buffer messages for this time to allow `since` requests"}), | ||||
| 	altsrc.NewDurationFlag(&cli.DurationFlag{Name: "keepalive-interval", Aliases: []string{"k"}, EnvVars: []string{"NTFY_KEEPALIVE_INTERVAL"}, Value: server.DefaultKeepaliveInterval, Usage: "interval of keepalive messages"}), | ||||
| 	altsrc.NewDurationFlag(&cli.DurationFlag{Name: "manager-interval", Aliases: []string{"m"}, EnvVars: []string{"NTFY_MANAGER_INTERVAL"}, Value: server.DefaultManagerInterval, Usage: "interval of for message pruning and stats printing"}), | ||||
| 	altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-addr", EnvVars: []string{"NTFY_SMTP_ADDR"}, Usage: "SMTP address (host:port) to allow email sending"}), | ||||
| 	altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-user", EnvVars: []string{"NTFY_SMTP_USER"}, Usage: "SMTP user (if e-mail sending is enabled)"}), | ||||
| 	altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-pass", EnvVars: []string{"NTFY_SMTP_PASS"}, Usage: "SMTP password (if e-mail sending is enabled)"}), | ||||
| 	altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-from", EnvVars: []string{"NTFY_SMTP_FROM"}, Usage: "SMTP sender address (if e-mail sending is enabled)"}), | ||||
| 	altsrc.NewIntFlag(&cli.IntFlag{Name: "global-topic-limit", Aliases: []string{"T"}, EnvVars: []string{"NTFY_GLOBAL_TOPIC_LIMIT"}, Value: server.DefaultGlobalTopicLimit, Usage: "total number of topics allowed"}), | ||||
| 	altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-subscription-limit", Aliases: []string{"V"}, EnvVars: []string{"NTFY_VISITOR_SUBSCRIPTION_LIMIT"}, Value: server.DefaultVisitorSubscriptionLimit, Usage: "number of subscriptions per visitor"}), | ||||
| 	altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-request-limit-burst", Aliases: []string{"B"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_BURST"}, Value: server.DefaultVisitorRequestLimitBurst, Usage: "initial limit of requests per visitor"}), | ||||
| 	altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-request-limit-replenish", Aliases: []string{"R"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_REPLENISH"}, Value: server.DefaultVisitorRequestLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}), | ||||
| 	altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-subscription-limit", EnvVars: []string{"NTFY_VISITOR_SUBSCRIPTION_LIMIT"}, Value: server.DefaultVisitorSubscriptionLimit, Usage: "number of subscriptions per visitor"}), | ||||
| 	altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-request-limit-burst", EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_BURST"}, Value: server.DefaultVisitorRequestLimitBurst, Usage: "initial limit of requests per visitor"}), | ||||
| 	altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-request-limit-replenish", EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_REPLENISH"}, Value: server.DefaultVisitorRequestLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}), | ||||
| 	altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-email-limit-burst", EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_BURST"}, Value: server.DefaultVisitorEmailLimitBurst, Usage: "initial limit of e-mails per visitor"}), | ||||
| 	altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-email-limit-replenish", EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_REPLENISH"}, Value: server.DefaultVisitorEmailLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}), | ||||
| 	altsrc.NewBoolFlag(&cli.BoolFlag{Name: "behind-proxy", Aliases: []string{"P"}, EnvVars: []string{"NTFY_BEHIND_PROXY"}, Value: false, Usage: "if set, use X-Forwarded-For header to determine visitor IP address (for rate limiting)"}), | ||||
| } | ||||
| 
 | ||||
|  | @ -61,10 +66,16 @@ func execServe(c *cli.Context) error { | |||
| 	cacheDuration := c.Duration("cache-duration") | ||||
| 	keepaliveInterval := c.Duration("keepalive-interval") | ||||
| 	managerInterval := c.Duration("manager-interval") | ||||
| 	smtpAddr := c.String("smtp-addr") | ||||
| 	smtpUser := c.String("smtp-user") | ||||
| 	smtpPass := c.String("smtp-pass") | ||||
| 	smtpFrom := c.String("smtp-from") | ||||
| 	globalTopicLimit := c.Int("global-topic-limit") | ||||
| 	visitorSubscriptionLimit := c.Int("visitor-subscription-limit") | ||||
| 	visitorRequestLimitBurst := c.Int("visitor-request-limit-burst") | ||||
| 	visitorRequestLimitReplenish := c.Duration("visitor-request-limit-replenish") | ||||
| 	visitorEmailLimitBurst := c.Int("visitor-email-limit-burst") | ||||
| 	visitorEmailLimitReplenish := c.Duration("visitor-email-limit-replenish") | ||||
| 	behindProxy := c.Bool("behind-proxy") | ||||
| 
 | ||||
| 	// Check values | ||||
|  | @ -82,6 +93,8 @@ func execServe(c *cli.Context) error { | |||
| 		return errors.New("if set, certificate file must exist") | ||||
| 	} else if listenHTTPS != "" && (keyFile == "" || certFile == "") { | ||||
| 		return errors.New("if listen-https is set, both key-file and cert-file must be set") | ||||
| 	} else if smtpAddr != "" && (smtpUser == "" || smtpPass == "" || smtpFrom == "") { | ||||
| 		return errors.New("if smtp-addr is set, smtp-user, smtp-pass and smtp-from must also be set") | ||||
| 	} | ||||
| 
 | ||||
| 	// Run server | ||||
|  | @ -95,11 +108,16 @@ func execServe(c *cli.Context) error { | |||
| 	conf.CacheDuration = cacheDuration | ||||
| 	conf.KeepaliveInterval = keepaliveInterval | ||||
| 	conf.ManagerInterval = managerInterval | ||||
| 	//XXXXXXXXX | ||||
| 	conf.SMTPAddr = smtpAddr | ||||
| 	conf.SMTPUser = smtpUser | ||||
| 	conf.SMTPPass = smtpPass | ||||
| 	conf.SMTPFrom = smtpFrom | ||||
| 	conf.GlobalTopicLimit = globalTopicLimit | ||||
| 	conf.VisitorSubscriptionLimit = visitorSubscriptionLimit | ||||
| 	conf.VisitorRequestLimitBurst = visitorRequestLimitBurst | ||||
| 	conf.VisitorRequestLimitReplenish = visitorRequestLimitReplenish | ||||
| 	conf.VisitorEmailLimitBurst = visitorEmailLimitBurst | ||||
| 	conf.VisitorEmailLimitReplenish = visitorEmailLimitReplenish | ||||
| 	conf.BehindProxy = behindProxy | ||||
| 	s, err := server.New(conf) | ||||
| 	if err != nil { | ||||
|  |  | |||
|  | @ -20,11 +20,14 @@ const ( | |||
| // Defines all the limits | ||||
| // - global topic limit: max number of topics overall | ||||
| // - per visitor request limit: max number of PUT/GET/.. requests (here: 60 requests bucket, replenished at a rate of one per 10 seconds) | ||||
| // - per visitor email limit: max number of emails (here: 16 email bucket, replenished at a rate of one per hour) | ||||
| // - per visitor subscription limit: max number of subscriptions (active HTTP connections) per per-visitor/IP | ||||
| const ( | ||||
| 	DefaultGlobalTopicLimit             = 5000 | ||||
| 	DefaultVisitorRequestLimitBurst     = 60 | ||||
| 	DefaultVisitorRequestLimitReplenish = 10 * time.Second | ||||
| 	DefaultVisitorEmailLimitBurst       = 16 | ||||
| 	DefaultVisitorEmailLimitReplenish   = time.Hour | ||||
| 	DefaultVisitorSubscriptionLimit     = 30 | ||||
| ) | ||||
| 
 | ||||
|  | @ -51,6 +54,8 @@ type Config struct { | |||
| 	GlobalTopicLimit             int | ||||
| 	VisitorRequestLimitBurst     int | ||||
| 	VisitorRequestLimitReplenish time.Duration | ||||
| 	VisitorEmailLimitBurst       int | ||||
| 	VisitorEmailLimitReplenish   time.Duration | ||||
| 	VisitorSubscriptionLimit     int | ||||
| 	BehindProxy                  bool | ||||
| } | ||||
|  | @ -75,6 +80,8 @@ func NewConfig() *Config { | |||
| 		GlobalTopicLimit:             DefaultGlobalTopicLimit, | ||||
| 		VisitorRequestLimitBurst:     DefaultVisitorRequestLimitBurst, | ||||
| 		VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish, | ||||
| 		VisitorEmailLimitBurst:       DefaultVisitorEmailLimitBurst, | ||||
| 		VisitorEmailLimitReplenish:   DefaultVisitorEmailLimitReplenish, | ||||
| 		VisitorSubscriptionLimit:     DefaultVisitorSubscriptionLimit, | ||||
| 		BehindProxy:                  false, | ||||
| 	} | ||||
|  |  | |||
							
								
								
									
										35
									
								
								server/mailer.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								server/mailer.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,35 @@ | |||
| package server | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net" | ||||
| 	"net/smtp" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| type mailer interface { | ||||
| 	Send(to string, m *message) error | ||||
| } | ||||
| 
 | ||||
| type smtpMailer struct { | ||||
| 	config *Config | ||||
| } | ||||
| 
 | ||||
| func (s *smtpMailer) Send(to string, m *message) error { | ||||
| 	host, _, err := net.SplitHostPort(s.config.SMTPAddr) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	subject := m.Title | ||||
| 	if subject == "" { | ||||
| 		subject = m.Message | ||||
| 	} | ||||
| 	subject += " - " + m.Topic | ||||
| 	subject = strings.ReplaceAll(strings.ReplaceAll(subject, "\r", ""), "\n", " ") | ||||
| 	msg := []byte(fmt.Sprintf("From: %s\r\n"+ | ||||
| 		"To: %s\r\n"+ | ||||
| 		"Subject: %s\r\n\r\n"+ | ||||
| 		"%s\r\n", s.config.SMTPFrom, to, subject, m.Message)) | ||||
| 	auth := smtp.PlainAuth("", s.config.SMTPUser, s.config.SMTPPass, host) | ||||
| 	return smtp.SendMail(s.config.SMTPAddr, auth, s.config.SMTPFrom, []string{to}, msg) | ||||
| } | ||||
|  | @ -3,7 +3,7 @@ package server | |||
| import ( | ||||
| 	"bytes" | ||||
| 	"context" | ||||
| 	"embed" // required for go:embed | ||||
| 	"embed" | ||||
| 	"encoding/json" | ||||
| 	firebase "firebase.google.com/go" | ||||
| 	"firebase.google.com/go/messaging" | ||||
|  | @ -15,7 +15,6 @@ import ( | |||
| 	"log" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"net/smtp" | ||||
| 	"regexp" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
|  | @ -34,6 +33,7 @@ type Server struct { | |||
| 	topics      map[string]*topic | ||||
| 	visitors    map[string]*visitor | ||||
| 	firebase    subscriber | ||||
| 	mailer      mailer | ||||
| 	messages    int64 | ||||
| 	cache       cache | ||||
| 	closeChan   chan bool | ||||
|  | @ -111,6 +111,7 @@ var ( | |||
| 
 | ||||
| const ( | ||||
| 	firebaseControlTopic = "~control" // See Android if changed | ||||
| 	emptyMessageBody     = "triggered" | ||||
| ) | ||||
| 
 | ||||
| // New instantiates a new Server. It creates the cache and adds a Firebase | ||||
|  | @ -124,6 +125,10 @@ func New(conf *Config) (*Server, error) { | |||
| 			return nil, err | ||||
| 		} | ||||
| 	} | ||||
| 	var mailer mailer | ||||
| 	if conf.SMTPAddr != "" { | ||||
| 		mailer = &smtpMailer{config: conf} | ||||
| 	} | ||||
| 	cache, err := createCache(conf) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
|  | @ -136,6 +141,7 @@ func New(conf *Config) (*Server, error) { | |||
| 		config:   conf, | ||||
| 		cache:    cache, | ||||
| 		firebase: firebaseSubscriber, | ||||
| 		mailer:   mailer, | ||||
| 		topics:   topics, | ||||
| 		visitors: make(map[string]*visitor), | ||||
| 	}, nil | ||||
|  | @ -189,23 +195,6 @@ func createFirebaseSubscriber(conf *Config) (subscriber, error) { | |||
| 	}, nil | ||||
| } | ||||
| 
 | ||||
| func (s *Server) sendMail(to string, m *message) error { | ||||
| 	host, _, err := net.SplitHostPort(s.config.SMTPAddr) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	subject := m.Title | ||||
| 	if subject == "" { | ||||
| 		subject = m.Message | ||||
| 	} | ||||
| 	msg := []byte(fmt.Sprintf("From: %s\r\n"+ | ||||
| 		"To: %s\r\n"+ | ||||
| 		"Subject: %s\r\n\r\n"+ | ||||
| 		"%s\r\n", s.config.SMTPFrom, to, subject, m.Message)) | ||||
| 	auth := smtp.PlainAuth("", s.config.SMTPUser, s.config.SMTPPass, host) | ||||
| 	return smtp.SendMail(s.config.SMTPAddr, auth, s.config.SMTPFrom, []string{to}, msg) | ||||
| } | ||||
| 
 | ||||
| // Run executes the main server. It listens on HTTP (+ HTTPS, if configured), and starts | ||||
| // a manager go routine to print stats and prune messages. | ||||
| func (s *Server) Run() error { | ||||
|  | @ -314,7 +303,7 @@ func (s *Server) handleDocs(w http.ResponseWriter, r *http.Request) error { | |||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, _ *visitor) error { | ||||
| func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visitor) error { | ||||
| 	t, err := s.topicFromPath(r.URL.Path) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
|  | @ -329,8 +318,16 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, _ *visito | |||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if email != "" { | ||||
| 		if err := v.EmailAllowed(); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 	if s.mailer == nil && email != "" { | ||||
| 		return errHTTPBadRequest | ||||
| 	} | ||||
| 	if m.Message == "" { | ||||
| 		m.Message = "triggered" | ||||
| 		m.Message = emptyMessageBody | ||||
| 	} | ||||
| 	delayed := m.Time > time.Now().Unix() | ||||
| 	if !delayed { | ||||
|  | @ -345,9 +342,9 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, _ *visito | |||
| 			} | ||||
| 		}() | ||||
| 	} | ||||
| 	if s.config.SMTPAddr != "" && email != "" && !delayed { | ||||
| 	if s.mailer != nil && email != "" && !delayed { | ||||
| 		go func() { | ||||
| 			if err := s.sendMail(email, m); err != nil { | ||||
| 			if err := s.mailer.Send(email, m); err != nil { | ||||
| 				log.Printf("Unable to send email: %v", err.Error()) | ||||
| 			} | ||||
| 		}() | ||||
|  | @ -369,7 +366,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, _ *visito | |||
| func (s *Server) parseParams(r *http.Request, m *message) (cache bool, firebase bool, email string, err error) { | ||||
| 	cache = readParam(r, "x-cache", "cache") != "no" | ||||
| 	firebase = readParam(r, "x-firebase", "firebase") != "no" | ||||
| 	email = readParam(r, "x-email", "email", "mail", "e") | ||||
| 	email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e") | ||||
| 	m.Title = readParam(r, "x-title", "title", "t") | ||||
| 	messageStr := readParam(r, "x-message", "message", "m") | ||||
| 	if messageStr != "" { | ||||
|  | @ -391,6 +388,9 @@ func (s *Server) parseParams(r *http.Request, m *message) (cache bool, firebase | |||
| 		if !cache { | ||||
| 			return false, false, "", errHTTPBadRequest | ||||
| 		} | ||||
| 		if email != "" { | ||||
| 			return false, false, "", errHTTPBadRequest // we cannot store the email address (yet) | ||||
| 		} | ||||
| 		delay, err := util.ParseFutureTime(delayStr, time.Now()) | ||||
| 		if err != nil { | ||||
| 			return false, false, "", errHTTPBadRequest | ||||
|  | @ -740,7 +740,7 @@ func (s *Server) sendDelayedMessages() error { | |||
| 					log.Printf("unable to publish to Firebase: %v", err.Error()) | ||||
| 				} | ||||
| 			} | ||||
| 			// FIXME delayed email | ||||
| 			// TODO delayed email sending | ||||
| 		} | ||||
| 		if err := s.cache.MarkPublished(m); err != nil { | ||||
| 			return err | ||||
|  |  | |||
|  | @ -61,6 +61,13 @@ | |||
| # visitor-request-limit-burst: 60 | ||||
| # visitor-request-limit-replenish: 10s | ||||
| 
 | ||||
| # Rate limiting: Allowed emails per visitor: | ||||
| # - visitor-email-limit-burst is the initial bucket of emails each visitor has | ||||
| # - visitor-email-limit-replenish is the rate at which the bucket is refilled | ||||
| # | ||||
| # visitor-email-limit-burst: 16 | ||||
| # visitor-email-limit-replenish: 1h | ||||
| 
 | ||||
| # If set, the X-Forwarded-For header is used to determine the visitor IP address | ||||
| # instead of the remote address of the connection. | ||||
| # | ||||
|  |  | |||
|  | @ -508,6 +508,76 @@ func TestServer_Curl_Publish_Poll(t *testing.T) { | |||
| } | ||||
| */ | ||||
| 
 | ||||
| type testMailer struct { | ||||
| 	count int | ||||
| } | ||||
| 
 | ||||
| func (t *testMailer) Send(to string, m *message) error { | ||||
| 	t.count++ | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func TestServer_PublishTooManyEmails_Defaults(t *testing.T) { | ||||
| 	s := newTestServer(t, newTestConfig(t)) | ||||
| 	s.mailer = &testMailer{} | ||||
| 	for i := 0; i < 16; i++ { | ||||
| 		response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), map[string]string{ | ||||
| 			"E-Mail": "test@example.com", | ||||
| 		}) | ||||
| 		require.Equal(t, 200, response.Code) | ||||
| 	} | ||||
| 	response := request(t, s, "PUT", "/mytopic", "one too many", map[string]string{ | ||||
| 		"E-Mail": "test@example.com", | ||||
| 	}) | ||||
| 	require.Equal(t, 429, response.Code) | ||||
| } | ||||
| 
 | ||||
| func TestServer_PublishTooManyEmails_Replenish(t *testing.T) { | ||||
| 	c := newTestConfig(t) | ||||
| 	c.VisitorEmailLimitReplenish = 500 * time.Millisecond | ||||
| 	s := newTestServer(t, c) | ||||
| 	s.mailer = &testMailer{} | ||||
| 	for i := 0; i < 16; i++ { | ||||
| 		response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), map[string]string{ | ||||
| 			"E-Mail": "test@example.com", | ||||
| 		}) | ||||
| 		require.Equal(t, 200, response.Code) | ||||
| 	} | ||||
| 	response := request(t, s, "PUT", "/mytopic", "one too many", map[string]string{ | ||||
| 		"E-Mail": "test@example.com", | ||||
| 	}) | ||||
| 	require.Equal(t, 429, response.Code) | ||||
| 
 | ||||
| 	time.Sleep(510 * time.Millisecond) | ||||
| 	response = request(t, s, "PUT", "/mytopic", "this should be okay again too many", map[string]string{ | ||||
| 		"E-Mail": "test@example.com", | ||||
| 	}) | ||||
| 	require.Equal(t, 200, response.Code) | ||||
| 
 | ||||
| 	response = request(t, s, "PUT", "/mytopic", "and bad again", map[string]string{ | ||||
| 		"E-Mail": "test@example.com", | ||||
| 	}) | ||||
| 	require.Equal(t, 429, response.Code) | ||||
| } | ||||
| 
 | ||||
| func TestServer_PublishDelayedEmail_Fail(t *testing.T) { | ||||
| 	s := newTestServer(t, newTestConfig(t)) | ||||
| 	s.mailer = &testMailer{} | ||||
| 	response := request(t, s, "PUT", "/mytopic", "fail", map[string]string{ | ||||
| 		"E-Mail": "test@example.com", | ||||
| 		"Delay":  "20 min", | ||||
| 	}) | ||||
| 	require.Equal(t, 400, response.Code) | ||||
| } | ||||
| 
 | ||||
| func TestServer_PublishEmailNoMailer_Fail(t *testing.T) { | ||||
| 	s := newTestServer(t, newTestConfig(t)) | ||||
| 	response := request(t, s, "PUT", "/mytopic", "fail", map[string]string{ | ||||
| 		"E-Mail": "test@example.com", | ||||
| 	}) | ||||
| 	require.Equal(t, 400, response.Code) | ||||
| } | ||||
| 
 | ||||
| func newTestConfig(t *testing.T) *Config { | ||||
| 	conf := NewConfig() | ||||
| 	conf.CacheFile = filepath.Join(t.TempDir(), "cache.db") | ||||
|  |  | |||
|  | @ -8,13 +8,17 @@ import ( | |||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	visitorExpungeAfter = 30 * time.Minute | ||||
| 	// visitorExpungeAfter defines how long a visitor is active before it is removed from memory. This number | ||||
| 	// has to be very high to prevent e-mail abuse, but it doesn't really affect the other limits anyway, since | ||||
| 	// they are replenished faster (typically). | ||||
| 	visitorExpungeAfter = 24 * time.Hour | ||||
| ) | ||||
| 
 | ||||
| // visitor represents an API user, and its associated rate.Limiter used for rate limiting | ||||
| type visitor struct { | ||||
| 	config        *Config | ||||
| 	limiter       *rate.Limiter | ||||
| 	requests      *rate.Limiter | ||||
| 	emails        *rate.Limiter | ||||
| 	subscriptions *util.Limiter | ||||
| 	seen          time.Time | ||||
| 	mu            sync.Mutex | ||||
|  | @ -23,14 +27,22 @@ type visitor struct { | |||
| func newVisitor(conf *Config) *visitor { | ||||
| 	return &visitor{ | ||||
| 		config:        conf, | ||||
| 		limiter:       rate.NewLimiter(rate.Every(conf.VisitorRequestLimitReplenish), conf.VisitorRequestLimitBurst), | ||||
| 		requests:      rate.NewLimiter(rate.Every(conf.VisitorRequestLimitReplenish), conf.VisitorRequestLimitBurst), | ||||
| 		emails:        rate.NewLimiter(rate.Every(conf.VisitorEmailLimitReplenish), conf.VisitorEmailLimitBurst), | ||||
| 		subscriptions: util.NewLimiter(int64(conf.VisitorSubscriptionLimit)), | ||||
| 		seen:          time.Now(), | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (v *visitor) RequestAllowed() error { | ||||
| 	if !v.limiter.Allow() { | ||||
| 	if !v.requests.Allow() { | ||||
| 		return errHTTPTooManyRequests | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (v *visitor) EmailAllowed() error { | ||||
| 	if !v.emails.Allow() { | ||||
| 		return errHTTPTooManyRequests | ||||
| 	} | ||||
| 	return nil | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Philipp Heckel
						Philipp Heckel