diff --git a/server/file_cache.go b/server/file_cache.go index 40346474..ad4961cc 100644 --- a/server/file_cache.go +++ b/server/file_cache.go @@ -40,7 +40,7 @@ func newFileCache(dir string, totalSizeLimit int64, fileSizeLimit int64) (*fileC }, nil } -func (c *fileCache) Write(id string, in io.Reader, limiters ...*util.Limiter) (int64, error) { +func (c *fileCache) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) { if !fileIDRegex.MatchString(id) { return 0, errInvalidFileID } @@ -53,7 +53,7 @@ func (c *fileCache) Write(id string, in io.Reader, limiters ...*util.Limiter) (i return 0, err } defer f.Close() - limiters = append(limiters, util.NewLimiter(c.Remaining()), util.NewLimiter(c.fileSizeLimit)) + limiters = append(limiters, util.NewFixedLimiter(c.Remaining()), util.NewFixedLimiter(c.fileSizeLimit)) limitWriter := util.NewLimitWriter(f, limiters...) size, err := io.Copy(limitWriter, in) if err != nil { diff --git a/server/file_cache_test.go b/server/file_cache_test.go index a0a74085..36d1d1a3 100644 --- a/server/file_cache_test.go +++ b/server/file_cache_test.go @@ -16,7 +16,7 @@ var ( func TestFileCache_Write_Success(t *testing.T) { dir, c := newTestFileCache(t) - size, err := c.Write("abc", strings.NewReader("normal file"), util.NewLimiter(999)) + size, err := c.Write("abc", strings.NewReader("normal file"), util.NewFixedLimiter(999)) require.Nil(t, err) require.Equal(t, int64(11), size) require.Equal(t, "normal file", readFile(t, dir+"/abc")) @@ -64,7 +64,7 @@ func TestFileCache_Write_FailedFileSizeLimit(t *testing.T) { func TestFileCache_Write_FailedAdditionalLimiter(t *testing.T) { dir, c := newTestFileCache(t) - _, err := c.Write("abc", bytes.NewReader(make([]byte, 1001)), util.NewLimiter(1000)) + _, err := c.Write("abc", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000)) require.Equal(t, util.ErrLimitReached, err) require.NoFileExists(t, dir+"/abc") } diff --git a/server/server.go b/server/server.go index c5c398db..33aacbd0 100644 --- a/server/server.go +++ b/server/server.go @@ -648,7 +648,7 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message, if m.Message == "" { m.Message = fmt.Sprintf(defaultAttachmentMessage, m.Attachment.Name) } - m.Attachment.Size, err = s.fileCache.Write(m.ID, body, util.NewLimiter(remainingVisitorAttachmentSize)) + m.Attachment.Size, err = s.fileCache.Write(m.ID, body, util.NewFixedLimiter(remainingVisitorAttachmentSize)) if err == util.ErrLimitReached { return errHTTPBadRequestAttachmentTooLarge } else if err != nil { diff --git a/server/server_test.go b/server/server_test.go index 90fbee00..0d81cb3c 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -909,13 +909,6 @@ func toMessage(t *testing.T, s string) *message { return &m } -func tempFile(t *testing.T, length int) (filename string, content string) { - filename = filepath.Join(t.TempDir(), util.RandomString(10)) - content = util.RandomString(length) - require.Nil(t, os.WriteFile(filename, []byte(content), 0600)) - return -} - func toHTTPError(t *testing.T, s string) *errHTTP { var e errHTTP require.Nil(t, json.NewDecoder(strings.NewReader(s)).Decode(&e)) diff --git a/server/visitor.go b/server/visitor.go index 63478798..26afa136 100644 --- a/server/visitor.go +++ b/server/visitor.go @@ -24,7 +24,7 @@ type visitor struct { config *Config ip string requests *rate.Limiter - subscriptions *util.Limiter + subscriptions util.Limiter emails *rate.Limiter seen time.Time mu sync.Mutex @@ -35,7 +35,7 @@ func newVisitor(conf *Config, ip string) *visitor { config: conf, ip: ip, requests: rate.NewLimiter(rate.Every(conf.VisitorRequestLimitReplenish), conf.VisitorRequestLimitBurst), - subscriptions: util.NewLimiter(int64(conf.VisitorSubscriptionLimit)), + subscriptions: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)), emails: rate.NewLimiter(rate.Every(conf.VisitorEmailLimitReplenish), conf.VisitorEmailLimitBurst), seen: time.Now(), } @@ -62,7 +62,7 @@ func (v *visitor) EmailAllowed() error { func (v *visitor) SubscriptionAllowed() error { v.mu.Lock() defer v.mu.Unlock() - if err := v.subscriptions.Add(1); err != nil { + if err := v.subscriptions.Allow(1); err != nil { return errVisitorLimitReached } return nil @@ -71,7 +71,7 @@ func (v *visitor) SubscriptionAllowed() error { func (v *visitor) RemoveSubscription() { v.mu.Lock() defer v.mu.Unlock() - v.subscriptions.Sub(1) + v.subscriptions.Allow(-1) } func (v *visitor) Keepalive() { diff --git a/util/limit.go b/util/limit.go index d7c3a8a6..8df768ad 100644 --- a/util/limit.go +++ b/util/limit.go @@ -2,31 +2,39 @@ package util import ( "errors" + "golang.org/x/time/rate" "io" "sync" + "time" ) // ErrLimitReached is the error returned by the Limiter and LimitWriter when the predefined limit has been reached var ErrLimitReached = errors.New("limit reached") -// Limiter is a helper that allows adding values up to a well-defined limit. Once the limit is reached -// ErrLimitReached will be returned. Limiter may be used by multiple goroutines. -type Limiter struct { +// Limiter is an interface that implements a rate limiting mechanism, e.g. based on time or a fixed value +type Limiter interface { + // Allow adds n to the limiters internal value, or returns ErrLimitReached if the limit has been reached + Allow(n int64) error +} + +// FixedLimiter is a helper that allows adding values up to a well-defined limit. Once the limit is reached +// ErrLimitReached will be returned. FixedLimiter may be used by multiple goroutines. +type FixedLimiter struct { value int64 limit int64 mu sync.Mutex } -// NewLimiter creates a new Limiter -func NewLimiter(limit int64) *Limiter { - return &Limiter{ +// NewFixedLimiter creates a new Limiter +func NewFixedLimiter(limit int64) *FixedLimiter { + return &FixedLimiter{ limit: limit, } } -// Add adds n to the limiters internal value, but only if the limit has not been reached. If the limit was +// Allow adds n to the limiters internal value, but only if the limit has not been reached. If the limit was // exceeded after adding n, ErrLimitReached is returned. -func (l *Limiter) Add(n int64) error { +func (l *FixedLimiter) Allow(n int64) error { l.mu.Lock() defer l.mu.Unlock() if l.value+n > l.limit { @@ -36,29 +44,34 @@ func (l *Limiter) Add(n int64) error { return nil } -// Sub subtracts a value from the limiters internal value -func (l *Limiter) Sub(n int64) { - l.Add(-n) +// RateLimiter is a Limiter that wraps a rate.Limiter, allowing a floating time-based limit. +type RateLimiter struct { + limiter *rate.Limiter } -// Set sets the value of the limiter to n. This function ignores the limit. It is meant to set the value -// based on reality. -func (l *Limiter) Set(n int64) { - l.mu.Lock() - l.value = n - l.mu.Unlock() +// NewRateLimiter creates a new RateLimiter +func NewRateLimiter(r rate.Limit, b int) *RateLimiter { + return &RateLimiter{ + limiter: rate.NewLimiter(r, b), + } } -// Value returns the internal value of the limiter -func (l *Limiter) Value() int64 { - l.mu.Lock() - defer l.mu.Unlock() - return l.value +// NewBytesLimiter creates a RateLimiter that is meant to be used for a bytes-per-interval limit, +// e.g. 250 MB per day. And example of the underlying idea can be found here: https://go.dev/play/p/0ljgzIZQ6dJ +func NewBytesLimiter(bytes int, interval time.Duration) *RateLimiter { + return NewRateLimiter(rate.Limit(bytes)*rate.Every(interval), bytes) } -// Limit returns the defined limit -func (l *Limiter) Limit() int64 { - return l.limit +// Allow adds n to the limiters internal value, but only if the limit has not been reached. If the limit was +// exceeded after adding n, ErrLimitReached is returned. +func (l *RateLimiter) Allow(n int64) error { + if n <= 0 { + return nil // No-op. Can't take back bytes you're written! + } + if !l.limiter.AllowN(time.Now(), int(n)) { + return ErrLimitReached + } + return nil } // LimitWriter implements an io.Writer that will pass through all Write calls to the underlying @@ -67,12 +80,12 @@ func (l *Limiter) Limit() int64 { type LimitWriter struct { w io.Writer written int64 - limiters []*Limiter + limiters []Limiter mu sync.Mutex } // NewLimitWriter creates a new LimitWriter -func NewLimitWriter(w io.Writer, limiters ...*Limiter) *LimitWriter { +func NewLimitWriter(w io.Writer, limiters ...Limiter) *LimitWriter { return &LimitWriter{ w: w, limiters: limiters, @@ -84,9 +97,9 @@ func (w *LimitWriter) Write(p []byte) (n int, err error) { w.mu.Lock() defer w.mu.Unlock() for i := 0; i < len(w.limiters); i++ { - if err := w.limiters[i].Add(int64(len(p))); err != nil { + if err := w.limiters[i].Allow(int64(len(p))); err != nil { for j := i - 1; j >= 0; j-- { - w.limiters[j].Sub(int64(len(p))) + w.limiters[j].Allow(-int64(len(p))) // Revert limiters limits if allowed } return 0, ErrLimitReached } diff --git a/util/limit_test.go b/util/limit_test.go index 4f07e00f..53e10b78 100644 --- a/util/limit_test.go +++ b/util/limit_test.go @@ -2,34 +2,51 @@ package util import ( "bytes" + "github.com/stretchr/testify/require" "testing" + "time" ) -func TestLimiter_Add(t *testing.T) { - l := NewLimiter(10) - if err := l.Add(5); err != nil { +func TestFixedLimiter_Add(t *testing.T) { + l := NewFixedLimiter(10) + if err := l.Allow(5); err != nil { t.Fatal(err) } - if err := l.Add(5); err != nil { + if err := l.Allow(5); err != nil { t.Fatal(err) } - if err := l.Add(5); err != ErrLimitReached { + if err := l.Allow(5); err != ErrLimitReached { t.Fatalf("expected ErrLimitReached, got %#v", err) } } -func TestLimiter_AddSet(t *testing.T) { - l := NewLimiter(10) - l.Add(5) - if l.Value() != 5 { - t.Fatalf("expected value to be %d, got %d", 5, l.Value()) +func TestFixedLimiter_AddSub(t *testing.T) { + l := NewFixedLimiter(10) + l.Allow(5) + if l.value != 5 { + t.Fatalf("expected value to be %d, got %d", 5, l.value) } - l.Set(7) - if l.Value() != 7 { - t.Fatalf("expected value to be %d, got %d", 7, l.Value()) + l.Allow(-2) + if l.value != 3 { + t.Fatalf("expected value to be %d, got %d", 7, l.value) } } +func TestBytesLimiter_Add_Simple(t *testing.T) { + l := NewBytesLimiter(250*1024*1024, 24*time.Hour) // 250 MB per 24h + require.Nil(t, l.Allow(100*1024*1024)) + require.Nil(t, l.Allow(100*1024*1024)) + require.Equal(t, ErrLimitReached, l.Allow(300*1024*1024)) +} + +func TestBytesLimiter_Add_Wait(t *testing.T) { + l := NewBytesLimiter(250*1024*1024, 24*time.Hour) // 250 MB per 24h (~ 303 bytes per 100ms) + require.Nil(t, l.Allow(250*1024*1024)) + require.Equal(t, ErrLimitReached, l.Allow(400)) + time.Sleep(200 * time.Millisecond) + require.Nil(t, l.Allow(400)) +} + func TestLimitWriter_WriteNoLimiter(t *testing.T) { var buf bytes.Buffer lw := NewLimitWriter(&buf) @@ -46,7 +63,7 @@ func TestLimitWriter_WriteNoLimiter(t *testing.T) { func TestLimitWriter_WriteOneLimiter(t *testing.T) { var buf bytes.Buffer - l := NewLimiter(10) + l := NewFixedLimiter(10) lw := NewLimitWriter(&buf, l) if _, err := lw.Write(make([]byte, 10)); err != nil { t.Fatal(err) @@ -57,15 +74,15 @@ func TestLimitWriter_WriteOneLimiter(t *testing.T) { if buf.Len() != 10 { t.Fatalf("expected buffer length to be %d, got %d", 10, buf.Len()) } - if l.Value() != 10 { - t.Fatalf("expected limiter value to be %d, got %d", 10, l.Value()) + if l.value != 10 { + t.Fatalf("expected limiter value to be %d, got %d", 10, l.value) } } func TestLimitWriter_WriteTwoLimiters(t *testing.T) { var buf bytes.Buffer - l1 := NewLimiter(11) - l2 := NewLimiter(9) + l1 := NewFixedLimiter(11) + l2 := NewFixedLimiter(9) lw := NewLimitWriter(&buf, l1, l2) if _, err := lw.Write(make([]byte, 8)); err != nil { t.Fatal(err) @@ -76,10 +93,47 @@ func TestLimitWriter_WriteTwoLimiters(t *testing.T) { if buf.Len() != 8 { t.Fatalf("expected buffer length to be %d, got %d", 8, buf.Len()) } - if l1.Value() != 8 { - t.Fatalf("expected limiter 1 value to be %d, got %d", 8, l1.Value()) + if l1.value != 8 { + t.Fatalf("expected limiter 1 value to be %d, got %d", 8, l1.value) } - if l2.Value() != 8 { - t.Fatalf("expected limiter 2 value to be %d, got %d", 8, l2.Value()) + if l2.value != 8 { + t.Fatalf("expected limiter 2 value to be %d, got %d", 8, l2.value) } } + +func TestLimitWriter_WriteTwoDifferentLimiters(t *testing.T) { + var buf bytes.Buffer + l1 := NewFixedLimiter(32) + l2 := NewBytesLimiter(8, 200*time.Millisecond) + lw := NewLimitWriter(&buf, l1, l2) + _, err := lw.Write(make([]byte, 8)) + require.Nil(t, err) + _, err = lw.Write(make([]byte, 4)) + require.Equal(t, ErrLimitReached, err) +} + +func TestLimitWriter_WriteTwoDifferentLimiters_Wait(t *testing.T) { + var buf bytes.Buffer + l1 := NewFixedLimiter(32) + l2 := NewBytesLimiter(8, 200*time.Millisecond) + lw := NewLimitWriter(&buf, l1, l2) + _, err := lw.Write(make([]byte, 8)) + require.Nil(t, err) + time.Sleep(250 * time.Millisecond) + _, err = lw.Write(make([]byte, 8)) + require.Nil(t, err) + _, err = lw.Write(make([]byte, 4)) + require.Equal(t, ErrLimitReached, err) +} + +func TestLimitWriter_WriteTwoDifferentLimiters_Wait_FixedLimiterFail(t *testing.T) { + var buf bytes.Buffer + l1 := NewFixedLimiter(11) // <<< This fails below + l2 := NewBytesLimiter(8, 200*time.Millisecond) + lw := NewLimitWriter(&buf, l1, l2) + _, err := lw.Write(make([]byte, 8)) + require.Nil(t, err) + time.Sleep(250 * time.Millisecond) + _, err = lw.Write(make([]byte, 8)) // <<< FixedLimiter fails + require.Equal(t, ErrLimitReached, err) +}