1
0
Fork 0
mirror of https://github.com/binwiederhier/ntfy.git synced 2025-06-09 14:34:36 +02:00

Making RateLimiter and FixedLimiter, so they can both work with LimitWriter

This commit is contained in:
Philipp Heckel 2022-01-12 17:03:28 -05:00
parent f6b9ebb693
commit c76e55a1c8
7 changed files with 127 additions and 67 deletions

View file

@ -40,7 +40,7 @@ func newFileCache(dir string, totalSizeLimit int64, fileSizeLimit int64) (*fileC
}, nil
}
func (c *fileCache) Write(id string, in io.Reader, limiters ...*util.Limiter) (int64, error) {
func (c *fileCache) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) {
if !fileIDRegex.MatchString(id) {
return 0, errInvalidFileID
}
@ -53,7 +53,7 @@ func (c *fileCache) Write(id string, in io.Reader, limiters ...*util.Limiter) (i
return 0, err
}
defer f.Close()
limiters = append(limiters, util.NewLimiter(c.Remaining()), util.NewLimiter(c.fileSizeLimit))
limiters = append(limiters, util.NewFixedLimiter(c.Remaining()), util.NewFixedLimiter(c.fileSizeLimit))
limitWriter := util.NewLimitWriter(f, limiters...)
size, err := io.Copy(limitWriter, in)
if err != nil {

View file

@ -16,7 +16,7 @@ var (
func TestFileCache_Write_Success(t *testing.T) {
dir, c := newTestFileCache(t)
size, err := c.Write("abc", strings.NewReader("normal file"), util.NewLimiter(999))
size, err := c.Write("abc", strings.NewReader("normal file"), util.NewFixedLimiter(999))
require.Nil(t, err)
require.Equal(t, int64(11), size)
require.Equal(t, "normal file", readFile(t, dir+"/abc"))
@ -64,7 +64,7 @@ func TestFileCache_Write_FailedFileSizeLimit(t *testing.T) {
func TestFileCache_Write_FailedAdditionalLimiter(t *testing.T) {
dir, c := newTestFileCache(t)
_, err := c.Write("abc", bytes.NewReader(make([]byte, 1001)), util.NewLimiter(1000))
_, err := c.Write("abc", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000))
require.Equal(t, util.ErrLimitReached, err)
require.NoFileExists(t, dir+"/abc")
}

View file

@ -648,7 +648,7 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message,
if m.Message == "" {
m.Message = fmt.Sprintf(defaultAttachmentMessage, m.Attachment.Name)
}
m.Attachment.Size, err = s.fileCache.Write(m.ID, body, util.NewLimiter(remainingVisitorAttachmentSize))
m.Attachment.Size, err = s.fileCache.Write(m.ID, body, util.NewFixedLimiter(remainingVisitorAttachmentSize))
if err == util.ErrLimitReached {
return errHTTPBadRequestAttachmentTooLarge
} else if err != nil {

View file

@ -909,13 +909,6 @@ func toMessage(t *testing.T, s string) *message {
return &m
}
func tempFile(t *testing.T, length int) (filename string, content string) {
filename = filepath.Join(t.TempDir(), util.RandomString(10))
content = util.RandomString(length)
require.Nil(t, os.WriteFile(filename, []byte(content), 0600))
return
}
func toHTTPError(t *testing.T, s string) *errHTTP {
var e errHTTP
require.Nil(t, json.NewDecoder(strings.NewReader(s)).Decode(&e))

View file

@ -24,7 +24,7 @@ type visitor struct {
config *Config
ip string
requests *rate.Limiter
subscriptions *util.Limiter
subscriptions util.Limiter
emails *rate.Limiter
seen time.Time
mu sync.Mutex
@ -35,7 +35,7 @@ func newVisitor(conf *Config, ip string) *visitor {
config: conf,
ip: ip,
requests: rate.NewLimiter(rate.Every(conf.VisitorRequestLimitReplenish), conf.VisitorRequestLimitBurst),
subscriptions: util.NewLimiter(int64(conf.VisitorSubscriptionLimit)),
subscriptions: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)),
emails: rate.NewLimiter(rate.Every(conf.VisitorEmailLimitReplenish), conf.VisitorEmailLimitBurst),
seen: time.Now(),
}
@ -62,7 +62,7 @@ func (v *visitor) EmailAllowed() error {
func (v *visitor) SubscriptionAllowed() error {
v.mu.Lock()
defer v.mu.Unlock()
if err := v.subscriptions.Add(1); err != nil {
if err := v.subscriptions.Allow(1); err != nil {
return errVisitorLimitReached
}
return nil
@ -71,7 +71,7 @@ func (v *visitor) SubscriptionAllowed() error {
func (v *visitor) RemoveSubscription() {
v.mu.Lock()
defer v.mu.Unlock()
v.subscriptions.Sub(1)
v.subscriptions.Allow(-1)
}
func (v *visitor) Keepalive() {