mirror of
https://github.com/binwiederhier/ntfy.git
synced 2025-06-09 14:34:36 +02:00
Making RateLimiter and FixedLimiter, so they can both work with LimitWriter
This commit is contained in:
parent
f6b9ebb693
commit
c76e55a1c8
7 changed files with 127 additions and 67 deletions
|
@ -40,7 +40,7 @@ func newFileCache(dir string, totalSizeLimit int64, fileSizeLimit int64) (*fileC
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (c *fileCache) Write(id string, in io.Reader, limiters ...*util.Limiter) (int64, error) {
|
||||
func (c *fileCache) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) {
|
||||
if !fileIDRegex.MatchString(id) {
|
||||
return 0, errInvalidFileID
|
||||
}
|
||||
|
@ -53,7 +53,7 @@ func (c *fileCache) Write(id string, in io.Reader, limiters ...*util.Limiter) (i
|
|||
return 0, err
|
||||
}
|
||||
defer f.Close()
|
||||
limiters = append(limiters, util.NewLimiter(c.Remaining()), util.NewLimiter(c.fileSizeLimit))
|
||||
limiters = append(limiters, util.NewFixedLimiter(c.Remaining()), util.NewFixedLimiter(c.fileSizeLimit))
|
||||
limitWriter := util.NewLimitWriter(f, limiters...)
|
||||
size, err := io.Copy(limitWriter, in)
|
||||
if err != nil {
|
||||
|
|
|
@ -16,7 +16,7 @@ var (
|
|||
|
||||
func TestFileCache_Write_Success(t *testing.T) {
|
||||
dir, c := newTestFileCache(t)
|
||||
size, err := c.Write("abc", strings.NewReader("normal file"), util.NewLimiter(999))
|
||||
size, err := c.Write("abc", strings.NewReader("normal file"), util.NewFixedLimiter(999))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(11), size)
|
||||
require.Equal(t, "normal file", readFile(t, dir+"/abc"))
|
||||
|
@ -64,7 +64,7 @@ func TestFileCache_Write_FailedFileSizeLimit(t *testing.T) {
|
|||
|
||||
func TestFileCache_Write_FailedAdditionalLimiter(t *testing.T) {
|
||||
dir, c := newTestFileCache(t)
|
||||
_, err := c.Write("abc", bytes.NewReader(make([]byte, 1001)), util.NewLimiter(1000))
|
||||
_, err := c.Write("abc", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000))
|
||||
require.Equal(t, util.ErrLimitReached, err)
|
||||
require.NoFileExists(t, dir+"/abc")
|
||||
}
|
||||
|
|
|
@ -648,7 +648,7 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message,
|
|||
if m.Message == "" {
|
||||
m.Message = fmt.Sprintf(defaultAttachmentMessage, m.Attachment.Name)
|
||||
}
|
||||
m.Attachment.Size, err = s.fileCache.Write(m.ID, body, util.NewLimiter(remainingVisitorAttachmentSize))
|
||||
m.Attachment.Size, err = s.fileCache.Write(m.ID, body, util.NewFixedLimiter(remainingVisitorAttachmentSize))
|
||||
if err == util.ErrLimitReached {
|
||||
return errHTTPBadRequestAttachmentTooLarge
|
||||
} else if err != nil {
|
||||
|
|
|
@ -909,13 +909,6 @@ func toMessage(t *testing.T, s string) *message {
|
|||
return &m
|
||||
}
|
||||
|
||||
func tempFile(t *testing.T, length int) (filename string, content string) {
|
||||
filename = filepath.Join(t.TempDir(), util.RandomString(10))
|
||||
content = util.RandomString(length)
|
||||
require.Nil(t, os.WriteFile(filename, []byte(content), 0600))
|
||||
return
|
||||
}
|
||||
|
||||
func toHTTPError(t *testing.T, s string) *errHTTP {
|
||||
var e errHTTP
|
||||
require.Nil(t, json.NewDecoder(strings.NewReader(s)).Decode(&e))
|
||||
|
|
|
@ -24,7 +24,7 @@ type visitor struct {
|
|||
config *Config
|
||||
ip string
|
||||
requests *rate.Limiter
|
||||
subscriptions *util.Limiter
|
||||
subscriptions util.Limiter
|
||||
emails *rate.Limiter
|
||||
seen time.Time
|
||||
mu sync.Mutex
|
||||
|
@ -35,7 +35,7 @@ func newVisitor(conf *Config, ip string) *visitor {
|
|||
config: conf,
|
||||
ip: ip,
|
||||
requests: rate.NewLimiter(rate.Every(conf.VisitorRequestLimitReplenish), conf.VisitorRequestLimitBurst),
|
||||
subscriptions: util.NewLimiter(int64(conf.VisitorSubscriptionLimit)),
|
||||
subscriptions: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)),
|
||||
emails: rate.NewLimiter(rate.Every(conf.VisitorEmailLimitReplenish), conf.VisitorEmailLimitBurst),
|
||||
seen: time.Now(),
|
||||
}
|
||||
|
@ -62,7 +62,7 @@ func (v *visitor) EmailAllowed() error {
|
|||
func (v *visitor) SubscriptionAllowed() error {
|
||||
v.mu.Lock()
|
||||
defer v.mu.Unlock()
|
||||
if err := v.subscriptions.Add(1); err != nil {
|
||||
if err := v.subscriptions.Allow(1); err != nil {
|
||||
return errVisitorLimitReached
|
||||
}
|
||||
return nil
|
||||
|
@ -71,7 +71,7 @@ func (v *visitor) SubscriptionAllowed() error {
|
|||
func (v *visitor) RemoveSubscription() {
|
||||
v.mu.Lock()
|
||||
defer v.mu.Unlock()
|
||||
v.subscriptions.Sub(1)
|
||||
v.subscriptions.Allow(-1)
|
||||
}
|
||||
|
||||
func (v *visitor) Keepalive() {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue