1
0
Fork 0
mirror of https://github.com/binwiederhier/ntfy.git synced 2024-12-22 17:52:30 +01:00

Rate limits make sense now!

This commit is contained in:
binwiederhier 2023-01-26 22:57:18 -05:00
parent a036814d98
commit c874a641df
17 changed files with 365 additions and 205 deletions

View file

@ -77,6 +77,7 @@ var flagsServe = append(
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-request-limit-burst", Aliases: []string{"visitor_request_limit_burst"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_BURST"}, Value: server.DefaultVisitorRequestLimitBurst, Usage: "initial limit of requests per visitor"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-request-limit-replenish", Aliases: []string{"visitor_request_limit_replenish"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_REPLENISH"}, Value: server.DefaultVisitorRequestLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "visitor-request-limit-exempt-hosts", Aliases: []string{"visitor_request_limit_exempt_hosts"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_EXEMPT_HOSTS"}, Value: "", Usage: "hostnames and/or IP addresses of hosts that will be exempt from the visitor request limit"}),
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-message-daily-limit", Aliases: []string{"visitor_message_daily_limit"}, EnvVars: []string{"NTFY_VISITOR_MESSAGE_DAILY_LIMIT"}, Value: server.DefaultVisitorMessageDailyLimit, Usage: "max messages per visitor per day, derived from request limit if unset"}),
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-email-limit-burst", Aliases: []string{"visitor_email_limit_burst"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_BURST"}, Value: server.DefaultVisitorEmailLimitBurst, Usage: "initial limit of e-mails per visitor"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-email-limit-replenish", Aliases: []string{"visitor_email_limit_replenish"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_REPLENISH"}, Value: server.DefaultVisitorEmailLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}),
altsrc.NewBoolFlag(&cli.BoolFlag{Name: "behind-proxy", Aliases: []string{"behind_proxy", "P"}, EnvVars: []string{"NTFY_BEHIND_PROXY"}, Value: false, Usage: "if set, use X-Forwarded-For header to determine visitor IP address (for rate limiting)"}),
@ -150,6 +151,7 @@ func execServe(c *cli.Context) error {
visitorRequestLimitBurst := c.Int("visitor-request-limit-burst")
visitorRequestLimitReplenish := c.Duration("visitor-request-limit-replenish")
visitorRequestLimitExemptHosts := util.SplitNoEmpty(c.String("visitor-request-limit-exempt-hosts"), ",")
visitorMessageDailyLimit := c.Int("visitor-message-daily-limit")
visitorEmailLimitBurst := c.Int("visitor-email-limit-burst")
visitorEmailLimitReplenish := c.Duration("visitor-email-limit-replenish")
behindProxy := c.Bool("behind-proxy")
@ -289,6 +291,7 @@ func execServe(c *cli.Context) error {
conf.VisitorRequestLimitBurst = visitorRequestLimitBurst
conf.VisitorRequestLimitReplenish = visitorRequestLimitReplenish
conf.VisitorRequestExemptIPAddrs = visitorRequestLimitExemptIPs
conf.VisitorMessageDailyLimit = visitorMessageDailyLimit
conf.VisitorEmailLimitBurst = visitorEmailLimitBurst
conf.VisitorEmailLimitReplenish = visitorEmailLimitReplenish
conf.BehindProxy = behindProxy

View file

@ -44,6 +44,7 @@ const (
DefaultVisitorSubscriptionLimit = 30
DefaultVisitorRequestLimitBurst = 60
DefaultVisitorRequestLimitReplenish = 5 * time.Second
DefaultVisitorMessageDailyLimit = 0
DefaultVisitorEmailLimitBurst = 16
DefaultVisitorEmailLimitReplenish = time.Hour
DefaultVisitorAccountCreationLimitBurst = 3
@ -105,6 +106,7 @@ type Config struct {
VisitorRequestLimitBurst int
VisitorRequestLimitReplenish time.Duration
VisitorRequestExemptIPAddrs []netip.Prefix
VisitorMessageDailyLimit int
VisitorEmailLimitBurst int
VisitorEmailLimitReplenish time.Duration
VisitorAccountCreationLimitBurst int
@ -171,6 +173,7 @@ func NewConfig() *Config {
VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst,
VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish,
VisitorRequestExemptIPAddrs: make([]netip.Prefix, 0),
VisitorMessageDailyLimit: DefaultVisitorMessageDailyLimit,
VisitorEmailLimitBurst: DefaultVisitorEmailLimitBurst,
VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish,
VisitorAccountCreationLimitBurst: DefaultVisitorAccountCreationLimitBurst,

View file

@ -75,10 +75,10 @@ var (
errHTTPTooManyRequestsLimitEmails = &errHTTP{42902, http.StatusTooManyRequests, "limit reached: too many emails, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPTooManyRequestsLimitSubscriptions = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPTooManyRequestsLimitTotalTopics = &errHTTP{42904, http.StatusTooManyRequests, "limit reached: the total number of topics on the server has been reached, please contact the admin", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPTooManyRequestsLimitAttachmentBandwidth = &errHTTP{42905, http.StatusTooManyRequests, "limit reached: daily bandwidth", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPTooManyRequestsLimitAttachmentBandwidth = &errHTTP{42905, http.StatusTooManyRequests, "limit reached: daily bandwidth reached", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPTooManyRequestsLimitAccountCreation = &errHTTP{42906, http.StatusTooManyRequests, "limit reached: too many accounts created", "https://ntfy.sh/docs/publish/#limitations"} // FIXME document limit
errHTTPTooManyRequestsLimitReservations = &errHTTP{42907, http.StatusTooManyRequests, "limit reached: too many topic reservations for this user", ""}
errHTTPTooManyRequestsLimitMessages = &errHTTP{42908, http.StatusTooManyRequests, "limit reached: too many messages", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPTooManyRequestsLimitMessages = &errHTTP{42908, http.StatusTooManyRequests, "limit reached: daily message quota reached", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""}
errHTTPInternalErrorInvalidPath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid path", ""}
errHTTPInternalErrorMissingBaseURL = &errHTTP{50003, http.StatusInternalServerError, "internal server error: base-url must be be configured for this feature", "https://ntfy.sh/docs/config/"}

View file

@ -38,10 +38,9 @@ import (
TODO
--
- HIGH Rate limiting: dailyLimitToRate is wrong? + TESTS
- HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...)
- HIGH Rate limiting: Delete visitor when tier is changed to refresh rate limiters
- HIGH Rate limiting: When ResetStats() is run, reset messagesLimiter (and others)?
- MEDIUM Rate limiting: Test daily message quota read from database initially
- MEDIUM: Races with v.user (see publishSyncEventAsync test)
- MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
- MEDIUM: Reservation (UI): Ask for confirmation when removing reservation (deadcade)
@ -57,7 +56,6 @@ Make sure account endpoints make sense for admins
Tests:
- Payment endpoints (make mocks)
- test that the visitor is based on the IP address when a user has no tier
*/
// Server is the main server, providing the UI and API for ntfy
@ -308,7 +306,7 @@ func (s *Server) Stop() {
}
func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
v, err := s.visitor(r) // Note: Always returns v, even when error is returned
v, err := s.maybeAuthenticate(r) // Note: Always returns v, even when error is returned
if err == nil {
log.Debug("%s Dispatching request", logHTTPPrefix(v, r))
if log.IsTrace() {
@ -563,7 +561,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
if v.user != nil {
m.User = v.user.ID
}
m.Expires = time.Now().Add(v.Limits().MessagesExpiryDuration).Unix()
m.Expires = time.Now().Add(v.Limits().MessageExpiryDuration).Unix()
if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil {
return nil, err
}
@ -601,7 +599,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
}
v.IncrementMessages()
if s.userManager != nil && v.user != nil {
s.userManager.EnqueueStats(v.user)
s.userManager.EnqueueStats(v.user) // FIXME this makes no sense for tier-less users
}
s.mu.Lock()
s.messages++
@ -1382,8 +1380,10 @@ func (s *Server) runStatsResetter() {
log.Debug("Stats resetter: Waiting until %v to reset visitor stats", runAt)
select {
case <-timer.C:
log.Debug("Stats resetter: Running")
s.resetStats()
case <-s.closeChan:
log.Debug("Stats resetter: Stopping timer")
timer.Stop()
return
}
@ -1440,17 +1440,15 @@ func (s *Server) sendDelayedMessages() error {
return err
}
for _, m := range messages {
var v *visitor
var u *user.User
if s.userManager != nil && m.User != "" {
u, err := s.userManager.User(m.User)
u, err = s.userManager.User(m.User)
if err != nil {
log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error())
log.Warn("Error sending delayed message %s: %s", m.ID, err.Error())
continue
}
v = s.visitorFromUser(u, m.Sender)
} else {
v = s.visitorFromIP(m.Sender)
}
v := s.visitor(m.Sender, u)
if err := s.sendDelayedMessage(v, m); err != nil {
log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error())
}
@ -1588,20 +1586,16 @@ func (s *Server) autorizeTopic(next handleFunc, perm user.Permission) handleFunc
}
}
// visitor creates or retrieves a rate.Limiter for the given visitor.
// maybeAuthenticate creates or retrieves a rate.Limiter for the given visitor.
// Note that this function will always return a visitor, even if an error occurs.
func (s *Server) visitor(r *http.Request) (v *visitor, err error) {
func (s *Server) maybeAuthenticate(r *http.Request) (v *visitor, err error) {
ip := extractIPAddress(r, s.config.BehindProxy)
var u *user.User // may stay nil if no auth header!
if u, err = s.authenticate(r); err != nil {
log.Debug("authentication failed: %s", err.Error())
err = errHTTPUnauthorized // Always return visitor, even when error occurs!
}
if u != nil {
v = s.visitorFromUser(u, ip)
} else {
v = s.visitorFromIP(ip)
}
v = s.visitor(ip, u)
v.SetUser(u) // Update visitor user with latest from database!
return v, err // Always return visitor, even when error occurs!
}
@ -1645,26 +1639,19 @@ func (s *Server) authenticateBearerAuth(value string) (user *user.User, err erro
return s.userManager.AuthenticateToken(token)
}
func (s *Server) visitorFromID(visitorID string, ip netip.Addr, user *user.User) *visitor {
func (s *Server) visitor(ip netip.Addr, user *user.User) *visitor {
s.mu.Lock()
defer s.mu.Unlock()
v, exists := s.visitors[visitorID]
id := visitorID(ip, user)
v, exists := s.visitors[id]
if !exists {
s.visitors[visitorID] = newVisitor(s.config, s.messageCache, s.userManager, ip, user)
return s.visitors[visitorID]
s.visitors[id] = newVisitor(s.config, s.messageCache, s.userManager, ip, user)
return s.visitors[id]
}
v.Keepalive()
return v
}
func (s *Server) visitorFromIP(ip netip.Addr) *visitor {
return s.visitorFromID(fmt.Sprintf("ip:%s", ip.String()), ip, nil)
}
func (s *Server) visitorFromUser(user *user.User, ip netip.Addr) *visitor {
return s.visitorFromID(fmt.Sprintf("user:%s", user.ID), ip, user)
}
func (s *Server) writeJSON(w http.ResponseWriter, v any) error {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests

View file

@ -200,6 +200,12 @@
# visitor-request-limit-replenish: "5s"
# visitor-request-limit-exempt-hosts: ""
# Rate limiting: Hard daily limit of messages per visitor and day. The limit is reset
# every day at midnight UTC. If the limit is not set (or set to zero), the request
# limit (see above) governs the upper limit.
#
# visitor-message-daily-limit: 0
# Rate limiting: Allowed emails per visitor:
# - visitor-email-limit-burst is the initial bucket of emails each visitor has
# - visitor-email-limit-replenish is the rate at which the bucket is refilled

View file

@ -23,6 +23,9 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *
} else if v.user != nil {
return errHTTPUnauthorized // Cannot create account from user context
}
if err := v.AccountCreationAllowed(); err != nil {
return errHTTPTooManyRequestsLimitAccountCreation
}
}
newAccount, err := readJSONWithLimit[apiAccountCreateRequest](r.Body, jsonBodyBytesLimit)
if err != nil {
@ -31,9 +34,6 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *
if existingUser, _ := s.userManager.User(newAccount.Username); existingUser != nil {
return errHTTPConflictUserExists
}
if err := v.AccountCreationAllowed(); err != nil {
return errHTTPTooManyRequestsLimitAccountCreation
}
if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser); err != nil { // TODO this should return a User
return err
}
@ -49,9 +49,9 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis
response := &apiAccountResponse{
Limits: &apiAccountLimits{
Basis: string(limits.Basis),
Messages: limits.MessagesLimit,
MessagesExpiryDuration: int64(limits.MessagesExpiryDuration.Seconds()),
Emails: limits.EmailsLimit,
Messages: limits.MessageLimit,
MessagesExpiryDuration: int64(limits.MessageExpiryDuration.Seconds()),
Emails: limits.EmailLimit,
Reservations: limits.ReservationsLimit,
AttachmentTotalSize: limits.AttachmentTotalSizeLimit,
AttachmentFileSize: limits.AttachmentFileSizeLimit,
@ -344,7 +344,7 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
reservations, err := s.userManager.ReservationsCount(v.user.Name)
if err != nil {
return err
} else if reservations >= v.user.Tier.ReservationsLimit {
} else if reservations >= v.user.Tier.ReservationLimit {
return errHTTPTooManyRequestsLimitReservations
}
}

View file

@ -410,10 +410,10 @@ func TestAccount_Reservation_AddRemoveUserWithTierSuccess(t *testing.T) {
// Create a tier
require.Nil(t, s.userManager.CreateTier(&user.Tier{
Code: "pro",
MessagesLimit: 123,
MessagesExpiryDuration: 86400 * time.Second,
EmailsLimit: 32,
ReservationsLimit: 2,
MessageLimit: 123,
MessageExpiryDuration: 86400 * time.Second,
EmailLimit: 32,
ReservationLimit: 2,
AttachmentFileSizeLimit: 1231231,
AttachmentTotalSizeLimit: 123123,
AttachmentExpiryDuration: 10800 * time.Second,
@ -491,9 +491,9 @@ func TestAccount_Reservation_PublishByAnonymousFails(t *testing.T) {
require.Equal(t, 200, rr.Code)
require.Nil(t, s.userManager.CreateTier(&user.Tier{
Code: "pro",
MessagesLimit: 20,
ReservationsLimit: 2,
Code: "pro",
MessageLimit: 20,
ReservationLimit: 2,
}))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
@ -525,9 +525,9 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) {
require.Equal(t, 200, rr.Code)
require.Nil(t, s.userManager.CreateTier(&user.Tier{
Code: "pro",
MessagesLimit: 20,
ReservationsLimit: 2,
Code: "pro",
MessageLimit: 20,
ReservationLimit: 2,
}))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
@ -591,10 +591,10 @@ func TestAccount_Tier_Create(t *testing.T) {
require.Nil(t, s.userManager.CreateTier(&user.Tier{
Code: "pro",
Name: "Pro",
MessagesLimit: 123,
MessagesExpiryDuration: 86400 * time.Second,
EmailsLimit: 32,
ReservationsLimit: 2,
MessageLimit: 123,
MessageExpiryDuration: 86400 * time.Second,
EmailLimit: 32,
ReservationLimit: 2,
AttachmentFileSizeLimit: 1231231,
AttachmentTotalSizeLimit: 123123,
AttachmentExpiryDuration: 10800 * time.Second,
@ -616,10 +616,10 @@ func TestAccount_Tier_Create(t *testing.T) {
require.True(t, strings.HasPrefix(ti.ID, "ti_"))
require.Equal(t, "pro", ti.Code)
require.Equal(t, "Pro", ti.Name)
require.Equal(t, int64(123), ti.MessagesLimit)
require.Equal(t, 86400*time.Second, ti.MessagesExpiryDuration)
require.Equal(t, int64(32), ti.EmailsLimit)
require.Equal(t, int64(2), ti.ReservationsLimit)
require.Equal(t, int64(123), ti.MessageLimit)
require.Equal(t, 86400*time.Second, ti.MessageExpiryDuration)
require.Equal(t, int64(32), ti.EmailLimit)
require.Equal(t, int64(2), ti.ReservationLimit)
require.Equal(t, int64(1231231), ti.AttachmentFileSizeLimit)
require.Equal(t, int64(123123), ti.AttachmentTotalSizeLimit)
require.Equal(t, 10800*time.Second, ti.AttachmentExpiryDuration)

View file

@ -60,15 +60,15 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
if err != nil {
return err
}
freeTier := defaultVisitorLimits(s.config)
freeTier := configBasedVisitorLimits(s.config)
response := []*apiAccountBillingTier{
{
// This is a bit of a hack: This is the "Free" tier. It has no tier code, name or price.
Limits: &apiAccountLimits{
Basis: string(visitorLimitBasisIP),
Messages: freeTier.MessagesLimit,
MessagesExpiryDuration: int64(freeTier.MessagesExpiryDuration.Seconds()),
Emails: freeTier.EmailsLimit,
Messages: freeTier.MessageLimit,
MessagesExpiryDuration: int64(freeTier.MessageExpiryDuration.Seconds()),
Emails: freeTier.EmailLimit,
Reservations: freeTier.ReservationsLimit,
AttachmentTotalSize: freeTier.AttachmentTotalSizeLimit,
AttachmentFileSize: freeTier.AttachmentFileSizeLimit,
@ -91,10 +91,10 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
Price: priceStr,
Limits: &apiAccountLimits{
Basis: string(visitorLimitBasisTier),
Messages: tier.MessagesLimit,
MessagesExpiryDuration: int64(tier.MessagesExpiryDuration.Seconds()),
Emails: tier.EmailsLimit,
Reservations: tier.ReservationsLimit,
Messages: tier.MessageLimit,
MessagesExpiryDuration: int64(tier.MessageExpiryDuration.Seconds()),
Emails: tier.EmailLimit,
Reservations: tier.ReservationLimit,
AttachmentTotalSize: tier.AttachmentTotalSizeLimit,
AttachmentFileSize: tier.AttachmentFileSizeLimit,
AttachmentExpiryDuration: int64(tier.AttachmentExpiryDuration.Seconds()),
@ -336,7 +336,7 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe
if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, tier, ev.Customer, subscriptionID, ev.Status, ev.CurrentPeriodEnd, ev.CancelAt); err != nil {
return err
}
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u))
return nil
}
@ -355,14 +355,14 @@ func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMe
if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, nil, ev.Customer, "", "", 0, 0); err != nil {
return err
}
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u))
return nil
}
func (s *Server) updateSubscriptionAndTier(logPrefix string, u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error {
reservationsLimit := visitorDefaultReservationsLimit
if tier != nil {
reservationsLimit = tier.ReservationsLimit
reservationsLimit = tier.ReservationLimit
}
if err := s.maybeRemoveMessagesAndExcessReservations(logPrefix, u, reservationsLimit); err != nil {
return err

View file

@ -5,11 +5,14 @@ import (
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/stripe/stripe-go/v74"
"golang.org/x/time/rate"
"heckel.io/ntfy/user"
"heckel.io/ntfy/util"
"io"
"net/netip"
"path/filepath"
"strings"
"sync"
"testing"
"time"
)
@ -48,10 +51,10 @@ func TestPayments_Tiers(t *testing.T) {
ID: "ti_123",
Code: "pro",
Name: "Pro",
MessagesLimit: 1000,
MessagesExpiryDuration: time.Hour,
EmailsLimit: 123,
ReservationsLimit: 777,
MessageLimit: 1000,
MessageExpiryDuration: time.Hour,
EmailLimit: 123,
ReservationLimit: 777,
AttachmentFileSizeLimit: 999,
AttachmentTotalSizeLimit: 888,
AttachmentExpiryDuration: time.Minute,
@ -61,10 +64,10 @@ func TestPayments_Tiers(t *testing.T) {
ID: "ti_444",
Code: "business",
Name: "Business",
MessagesLimit: 2000,
MessagesExpiryDuration: 10 * time.Hour,
EmailsLimit: 123123,
ReservationsLimit: 777333,
MessageLimit: 2000,
MessageExpiryDuration: 10 * time.Hour,
EmailLimit: 123123,
ReservationLimit: 777333,
AttachmentFileSizeLimit: 999111,
AttachmentTotalSizeLimit: 888111,
AttachmentExpiryDuration: time.Hour,
@ -238,9 +241,14 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
require.Equal(t, 401, rr.Code)
}
func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *testing.T) {
// This tests a successful checkout flow (not a paying customer -> paying customer),
// and also tests that during the upgrade we are RESETTING THE RATE LIMITS of the existing user.
func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *testing.T) {
// This test is too overloaded, but it's also a great end-to-end a test.
//
// It tests:
// - A successful checkout flow (not a paying customer -> paying customer)
// - Tier-changes reset the rate limits for the user
// - The request limits for tier-less user and a tier-user
// - The message limits for a tier-user
stripeMock := &testStripeAPI{}
defer stripeMock.AssertExpectations(t)
@ -248,19 +256,26 @@ func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *test
c := newTestConfigWithAuthFile(t)
c.StripeSecretKey = "secret key"
c.StripeWebhookKey = "webhook key"
c.VisitorRequestLimitBurst = 10
c.VisitorRequestLimitBurst = 5
c.VisitorRequestLimitReplenish = time.Hour
c.CacheStartupQueries = `
pragma journal_mode = WAL;
pragma synchronous = normal;
pragma temp_store = memory;
`
c.CacheBatchSize = 500
c.CacheBatchTimeout = time.Second
s := newTestServer(t, c)
s.stripe = stripeMock
// Create a user with a Stripe subscription and 3 reservations
require.Nil(t, s.userManager.CreateTier(&user.Tier{
ID: "ti_123",
Code: "starter",
StripePriceID: "price_1234",
ReservationsLimit: 1,
MessagesLimit: 100,
MessagesExpiryDuration: time.Hour,
ID: "ti_123",
Code: "starter",
StripePriceID: "price_1234",
ReservationLimit: 1,
MessageLimit: 220, // 220 * 5% = 11 requests before rate limiting kicks in
MessageExpiryDuration: time.Hour,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) // No tier
u, err := s.userManager.User("phil")
@ -298,7 +313,7 @@ func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *test
Return(&stripe.Customer{}, nil)
// Send messages until rate limit of free tier is hit
for i := 0; i < 10; i++ {
for i := 0; i < 5; i++ {
rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
@ -323,10 +338,9 @@ func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *test
require.Equal(t, int64(123456789), u.Billing.StripeSubscriptionPaidUntil.Unix())
require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
// FIXME FIXME This test is broken, because the rate limit logic is unclear!
// Now for the fun part: Verify that new rate limits are immediately applied
for i := 0; i < 100; i++ {
// This only tests the request limiter, which kicks in before the message limiter.
for i := 0; i < 11; i++ {
rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
@ -336,6 +350,37 @@ func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *test
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 429, rr.Code)
// Now let's test the message limiter by faking a ridiculously generous rate limiter
v := s.visitor(netip.MustParseAddr("9.9.9.9"), u)
v.requestLimiter = rate.NewLimiter(rate.Every(time.Millisecond), 1000000)
var wg sync.WaitGroup
for i := 0; i < 209; i++ {
wg.Add(1)
go func() {
rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
wg.Done()
}()
}
wg.Wait()
rr = request(t, s, "PUT", "/mytopic", "some message", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 429, rr.Code)
// And now let's cross-check that the stats are correct too
rr = request(t, s, "GET", "/v1/account", "", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
account, _ := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(rr.Body))
require.Equal(t, int64(220), account.Limits.Messages)
require.Equal(t, int64(220), account.Stats.Messages)
require.Equal(t, int64(0), account.Stats.MessagesRemaining)
}
func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) {
@ -363,9 +408,9 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
ID: "ti_1",
Code: "starter",
StripePriceID: "price_1234", // !
ReservationsLimit: 1, // !
MessagesLimit: 100,
MessagesExpiryDuration: time.Hour,
ReservationLimit: 1, // !
MessageLimit: 100,
MessageExpiryDuration: time.Hour,
AttachmentExpiryDuration: time.Hour,
AttachmentFileSizeLimit: 1000000,
AttachmentTotalSizeLimit: 1000000,
@ -375,9 +420,9 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
ID: "ti_2",
Code: "pro",
StripePriceID: "price_1111", // !
ReservationsLimit: 3, // !
MessagesLimit: 200,
MessagesExpiryDuration: time.Hour,
ReservationLimit: 3, // !
MessageLimit: 200,
MessageExpiryDuration: time.Hour,
AttachmentExpiryDuration: time.Hour,
AttachmentFileSizeLimit: 1000000,
AttachmentTotalSizeLimit: 1000000,

View file

@ -8,7 +8,6 @@ import (
"fmt"
"heckel.io/ntfy/user"
"io"
"log"
"math/rand"
"net/http"
"net/http/httptest"
@ -22,9 +21,14 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/log"
"heckel.io/ntfy/util"
)
func init() {
// log.SetLevel(log.DebugLevel)
}
func TestServer_PublishAndPoll(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
@ -742,16 +746,31 @@ func TestServer_Auth_ViaQuery(t *testing.T) {
require.Equal(t, 401, response.Code)
}
func TestServer_StatsResetter(t *testing.T) {
func TestServer_StatsResetter_User_Without_Tier(t *testing.T) {
// This tests the stats resetter for
// - an anonymous user
// - a user without a tier (treated like the same as the anonymous user)
// - a user with a tier
c := newTestConfigWithAuthFile(t)
c.AuthDefault = user.PermissionDenyAll
c.VisitorStatsResetTime = time.Now().Add(2 * time.Second)
s := newTestServer(t, c)
go s.runStatsResetter()
// Create user with tier (tieruser) and user without tier (phil)
require.Nil(t, s.userManager.CreateTier(&user.Tier{
Code: "test",
MessageLimit: 5,
MessageExpiryDuration: -5 * time.Second, // Second, what a hack!
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.AllowAccess("phil", "mytopic", user.PermissionReadWrite))
require.Nil(t, s.userManager.AddUser("tieruser", "tieruser", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("tieruser", "test"))
// Send an anonymous message
response := request(t, s, "PUT", "/mytopic", "test", nil)
// Send messages from user without tier (phil)
for i := 0; i < 5; i++ {
response := request(t, s, "PUT", "/mytopic", "test", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
@ -759,30 +778,66 @@ func TestServer_StatsResetter(t *testing.T) {
require.Equal(t, 200, response.Code)
}
response := request(t, s, "GET", "/v1/account", "", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, response.Code)
// Send messages from user with tier
for i := 0; i < 2; i++ {
response := request(t, s, "PUT", "/mytopic", "test", map[string]string{
"Authorization": util.BasicAuth("tieruser", "tieruser"),
})
require.Equal(t, 200, response.Code)
}
// User stats show 10 messages
// User stats show 6 messages (for user without tier)
response = request(t, s, "GET", "/v1/account", "", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, response.Code)
account, err := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
require.Nil(t, err)
require.Equal(t, int64(5), account.Stats.Messages)
require.Equal(t, int64(6), account.Stats.Messages)
// User stats show 6 messages (for anonymous visitor)
response = request(t, s, "GET", "/v1/account", "", nil)
require.Equal(t, 200, response.Code)
account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
require.Nil(t, err)
require.Equal(t, int64(6), account.Stats.Messages)
// User stats show 2 messages (for user with tier)
response = request(t, s, "GET", "/v1/account", "", map[string]string{
"Authorization": util.BasicAuth("tieruser", "tieruser"),
})
require.Equal(t, 200, response.Code)
account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
require.Nil(t, err)
require.Equal(t, int64(2), account.Stats.Messages)
// Wait for stats resetter to run
time.Sleep(2200 * time.Millisecond)
// User stats show 0 messages now!
response = request(t, s, "GET", "/v1/account", "", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, response.Code)
account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
require.Nil(t, err)
require.Equal(t, int64(0), account.Stats.Messages)
// Since this is a user without a tier, the anonymous user should have the same stats
response = request(t, s, "GET", "/v1/account", "", nil)
require.Equal(t, 200, response.Code)
account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
require.Nil(t, err)
require.Equal(t, int64(0), account.Stats.Messages)
// User stats show 0 messages (for user with tier)
response = request(t, s, "GET", "/v1/account", "", map[string]string{
"Authorization": util.BasicAuth("tieruser", "tieruser"),
})
require.Equal(t, 200, response.Code)
account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
require.Nil(t, err)
require.Equal(t, int64(0), account.Stats.Messages)
}
type testMailer struct {
@ -1133,9 +1188,9 @@ func TestServer_PublishWithTierBasedMessageLimitAndExpiry(t *testing.T) {
// Create tier with certain limits
require.Nil(t, s.userManager.CreateTier(&user.Tier{
Code: "test",
MessagesLimit: 5,
MessagesExpiryDuration: -5 * time.Second, // Second, what a hack!
Code: "test",
MessageLimit: 5,
MessageExpiryDuration: -5 * time.Second, // Second, what a hack!
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("phil", "test"))
@ -1363,8 +1418,8 @@ func TestServer_PublishAttachmentWithTierBasedExpiry(t *testing.T) {
sevenDays := time.Duration(604800) * time.Second
require.Nil(t, s.userManager.CreateTier(&user.Tier{
Code: "test",
MessagesLimit: 10,
MessagesExpiryDuration: sevenDays,
MessageLimit: 10,
MessageExpiryDuration: sevenDays,
AttachmentFileSizeLimit: 50_000,
AttachmentTotalSizeLimit: 200_000,
AttachmentExpiryDuration: sevenDays, // 7 days
@ -1407,8 +1462,8 @@ func TestServer_PublishAttachmentWithTierBasedBandwidthLimit(t *testing.T) {
// Create tier with certain limits
require.Nil(t, s.userManager.CreateTier(&user.Tier{
Code: "test",
MessagesLimit: 10,
MessagesExpiryDuration: time.Hour,
MessageLimit: 10,
MessageExpiryDuration: time.Hour,
AttachmentFileSizeLimit: 50_000,
AttachmentTotalSizeLimit: 200_000,
AttachmentExpiryDuration: time.Hour,
@ -1450,7 +1505,7 @@ func TestServer_PublishAttachmentWithTierBasedLimits(t *testing.T) {
// Create tier with certain limits
require.Nil(t, s.userManager.CreateTier(&user.Tier{
Code: "test",
MessagesLimit: 100,
MessageLimit: 100,
AttachmentFileSizeLimit: 50_000,
AttachmentTotalSizeLimit: 200_000,
AttachmentExpiryDuration: 30 * time.Second,
@ -1574,7 +1629,7 @@ func TestServer_Visitor_XForwardedFor_None(t *testing.T) {
r, _ := http.NewRequest("GET", "/bla", nil)
r.RemoteAddr = "8.9.10.11"
r.Header.Set("X-Forwarded-For", " ") // Spaces, not empty!
v, err := s.visitor(r)
v, err := s.maybeAuthenticate(r)
require.Nil(t, err)
require.Equal(t, "8.9.10.11", v.ip.String())
}
@ -1586,7 +1641,7 @@ func TestServer_Visitor_XForwardedFor_Single(t *testing.T) {
r, _ := http.NewRequest("GET", "/bla", nil)
r.RemoteAddr = "8.9.10.11"
r.Header.Set("X-Forwarded-For", "1.1.1.1")
v, err := s.visitor(r)
v, err := s.maybeAuthenticate(r)
require.Nil(t, err)
require.Equal(t, "1.1.1.1", v.ip.String())
}
@ -1598,7 +1653,7 @@ func TestServer_Visitor_XForwardedFor_Multiple(t *testing.T) {
r, _ := http.NewRequest("GET", "/bla", nil)
r.RemoteAddr = "8.9.10.11"
r.Header.Set("X-Forwarded-For", "1.2.3.4 , 2.4.4.2,234.5.2.1 ")
v, err := s.visitor(r)
v, err := s.maybeAuthenticate(r)
require.Nil(t, err)
require.Equal(t, "234.5.2.1", v.ip.String())
}
@ -1611,7 +1666,7 @@ func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) {
s := newTestServer(t, c)
// Add lots of messages
log.Printf("Adding %d messages", count)
log.Info("Adding %d messages", count)
start := time.Now()
messages := make([]*message, 0)
for i := 0; i < count; i++ {
@ -1621,31 +1676,31 @@ func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) {
messages = append(messages, newDefaultMessage(topicID, "some message"))
}
require.Nil(t, s.messageCache.addMessages(messages))
log.Printf("Done: Adding %d messages; took %s", count, time.Since(start).Round(time.Millisecond))
log.Info("Done: Adding %d messages; took %s", count, time.Since(start).Round(time.Millisecond))
// Update stats
statsChan := make(chan bool)
go func() {
log.Printf("Updating stats")
log.Info("Updating stats")
start := time.Now()
s.execManager()
log.Printf("Done: Updating stats; took %s", time.Since(start).Round(time.Millisecond))
log.Info("Done: Updating stats; took %s", time.Since(start).Round(time.Millisecond))
statsChan <- true
}()
time.Sleep(50 * time.Millisecond) // Make sure it starts first
// Publish message (during stats update)
log.Printf("Publishing message")
log.Info("Publishing message")
start = time.Now()
response := request(t, s, "PUT", "/mytopic", "some body", nil)
m := toMessage(t, response.Body.String())
assert.Equal(t, "some body", m.Message)
assert.True(t, time.Since(start) < 100*time.Millisecond)
log.Printf("Done: Publishing message; took %s", time.Since(start).Round(time.Millisecond))
log.Info("Done: Publishing message; took %s", time.Since(start).Round(time.Millisecond))
// Wait for all goroutines
<-statsChan
log.Printf("Done: Waiting for all locks")
log.Info("Done: Waiting for all locks")
}
func newTestConfig(t *testing.T) *Config {

View file

@ -14,16 +14,39 @@ import (
)
const (
// oneDay is an approximation of a day as a time.Duration
oneDay = 24 * time.Hour
// visitorExpungeAfter defines how long a visitor is active before it is removed from memory. This number
// has to be very high to prevent e-mail abuse, but it doesn't really affect the other limits anyway, since
// they are replenished faster (typically).
visitorExpungeAfter = 24 * time.Hour
visitorExpungeAfter = oneDay
// visitorDefaultReservationsLimit is the amount of topic names a user without a tier is allowed to reserve.
// This number is zero, and changing it may have unintended consequences in the web app, or otherwise
visitorDefaultReservationsLimit = int64(0)
)
// Constants used to convert a tier-user's MessageLimit (see user.Tier) into adequate request limiter
// values (token bucket).
//
// Example: Assuming a user.Tier's MessageLimit is 10,000:
// - the allowed burst is 500 (= 10,000 * 5%), which is < 1000 (the max)
// - the replenish rate is 2 * 10,000 / 24 hours
const (
visitorMessageToRequestLimitBurstRate = 0.05
visitorMessageToRequestLimitBurstMax = 1000
visitorMessageToRequestLimitReplenishFactor = 2
)
// Constants used to convert a tier-user's EmailLimit (see user.Tier) into adequate email limiter
// values (token bucket). Example: Assuming a user.Tier's EmailLimit is 200, the allowed burst is
// 40 (= 200 * 20%), which is <150 (the max).
const (
visitorEmailLimitBurstRate = 0.2
visitorEmailLimitBurstMax = 150
)
var (
errVisitorLimitReached = errors.New("limit reached")
)
@ -55,9 +78,13 @@ type visitorInfo struct {
type visitorLimits struct {
Basis visitorLimitBasis
MessagesLimit int64
MessagesExpiryDuration time.Duration
EmailsLimit int64
RequestLimitBurst int
RequestLimitReplenish rate.Limit
MessageLimit int64
MessageExpiryDuration time.Duration
EmailLimit int64
EmailLimitBurst int
EmailLimitReplenish rate.Limit
ReservationsLimit int64
AttachmentTotalSizeLimit int64
AttachmentFileSizeLimit int64
@ -173,7 +200,7 @@ func (v *visitor) SubscriptionAllowed() error {
}
func (v *visitor) AccountCreationAllowed() error {
if v.accountLimiter != nil && !v.accountLimiter.Allow() {
if v.accountLimiter == nil || (v.accountLimiter != nil && !v.accountLimiter.Allow()) {
return errVisitorLimitReached
}
return nil
@ -242,31 +269,6 @@ func (v *visitor) SetUser(u *user.User) {
}
}
func (v *visitor) resetLimiters() {
log.Info("%s Resetting limiters for visitor", v.stringNoLock())
var messagesLimiter, bandwidthLimiter util.Limiter
var requestLimiter, emailsLimiter, accountLimiter *rate.Limiter
if v.user != nil && v.user.Tier != nil {
requestLimiter = rate.NewLimiter(dailyLimitToRate(v.user.Tier.MessagesLimit), v.config.VisitorRequestLimitBurst)
messagesLimiter = util.NewFixedLimiter(v.user.Tier.MessagesLimit)
emailsLimiter = rate.NewLimiter(dailyLimitToRate(v.user.Tier.EmailsLimit), v.config.VisitorEmailLimitBurst)
bandwidthLimiter = util.NewBytesLimiter(int(v.user.Tier.AttachmentBandwidthLimit), 24*time.Hour)
} else {
requestLimiter = rate.NewLimiter(rate.Every(v.config.VisitorRequestLimitReplenish), v.config.VisitorRequestLimitBurst)
messagesLimiter = nil // Message limit is governed by the requestLimiter
emailsLimiter = rate.NewLimiter(rate.Every(v.config.VisitorEmailLimitReplenish), v.config.VisitorEmailLimitBurst)
bandwidthLimiter = util.NewBytesLimiter(int(v.config.VisitorAttachmentDailyBandwidthLimit), 24*time.Hour)
}
if v.user == nil {
accountLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAccountCreationLimitReplenish), v.config.VisitorAccountCreationLimitBurst)
}
v.requestLimiter = requestLimiter
v.messagesLimiter = messagesLimiter
v.emailsLimiter = emailsLimiter
v.bandwidthLimiter = bandwidthLimiter
v.accountLimiter = accountLimiter
}
// MaybeUserID returns the user ID of the visitor (if any). If this is an anonymous visitor,
// an empty string is returned.
func (v *visitor) MaybeUserID() string {
@ -278,22 +280,71 @@ func (v *visitor) MaybeUserID() string {
return ""
}
func (v *visitor) resetLimiters() {
log.Debug("%s Resetting limiters for visitor", v.stringNoLock())
limits := v.limitsNoLock()
v.requestLimiter = rate.NewLimiter(limits.RequestLimitReplenish, limits.RequestLimitBurst)
v.messagesLimiter = util.NewFixedLimiterWithValue(limits.MessageLimit, v.messages)
v.emailsLimiter = rate.NewLimiter(limits.EmailLimitReplenish, limits.EmailLimitBurst)
v.bandwidthLimiter = util.NewBytesLimiter(int(limits.AttachmentBandwidthLimit), oneDay)
if v.user == nil {
v.accountLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAccountCreationLimitReplenish), v.config.VisitorAccountCreationLimitBurst)
} else {
v.accountLimiter = nil // Users cannot create accounts when logged in
}
}
func (v *visitor) Limits() *visitorLimits {
v.mu.Lock()
defer v.mu.Unlock()
limits := defaultVisitorLimits(v.config)
return v.limitsNoLock()
}
func (v *visitor) limitsNoLock() *visitorLimits {
if v.user != nil && v.user.Tier != nil {
limits.Basis = visitorLimitBasisTier
limits.MessagesLimit = v.user.Tier.MessagesLimit
limits.MessagesExpiryDuration = v.user.Tier.MessagesExpiryDuration
limits.EmailsLimit = v.user.Tier.EmailsLimit
limits.ReservationsLimit = v.user.Tier.ReservationsLimit
limits.AttachmentTotalSizeLimit = v.user.Tier.AttachmentTotalSizeLimit
limits.AttachmentFileSizeLimit = v.user.Tier.AttachmentFileSizeLimit
limits.AttachmentExpiryDuration = v.user.Tier.AttachmentExpiryDuration
limits.AttachmentBandwidthLimit = v.user.Tier.AttachmentBandwidthLimit
return tierBasedVisitorLimits(v.config, v.user.Tier)
}
return configBasedVisitorLimits(v.config)
}
func tierBasedVisitorLimits(conf *Config, tier *user.Tier) *visitorLimits {
return &visitorLimits{
Basis: visitorLimitBasisTier,
RequestLimitBurst: util.MinMax(int(float64(tier.MessageLimit)*visitorMessageToRequestLimitBurstRate), conf.VisitorRequestLimitBurst, visitorMessageToRequestLimitBurstMax),
RequestLimitReplenish: dailyLimitToRate(tier.MessageLimit * visitorMessageToRequestLimitReplenishFactor),
MessageLimit: tier.MessageLimit,
MessageExpiryDuration: tier.MessageExpiryDuration,
EmailLimit: tier.EmailLimit,
EmailLimitBurst: util.MinMax(int(float64(tier.EmailLimit)*visitorEmailLimitBurstRate), conf.VisitorEmailLimitBurst, visitorEmailLimitBurstMax),
EmailLimitReplenish: dailyLimitToRate(tier.EmailLimit),
ReservationsLimit: tier.ReservationLimit,
AttachmentTotalSizeLimit: tier.AttachmentTotalSizeLimit,
AttachmentFileSizeLimit: tier.AttachmentFileSizeLimit,
AttachmentExpiryDuration: tier.AttachmentExpiryDuration,
AttachmentBandwidthLimit: tier.AttachmentBandwidthLimit,
}
}
func configBasedVisitorLimits(conf *Config) *visitorLimits {
messagesLimit := replenishDurationToDailyLimit(conf.VisitorRequestLimitReplenish) // Approximation!
if conf.VisitorMessageDailyLimit > 0 {
messagesLimit = int64(conf.VisitorMessageDailyLimit)
}
return &visitorLimits{
Basis: visitorLimitBasisIP,
RequestLimitBurst: conf.VisitorRequestLimitBurst,
RequestLimitReplenish: rate.Every(conf.VisitorRequestLimitReplenish),
MessageLimit: messagesLimit,
MessageExpiryDuration: conf.CacheDuration,
EmailLimit: replenishDurationToDailyLimit(conf.VisitorEmailLimitReplenish), // Approximation!
EmailLimitBurst: conf.VisitorEmailLimitBurst,
EmailLimitReplenish: rate.Every(conf.VisitorEmailLimitReplenish),
ReservationsLimit: visitorDefaultReservationsLimit,
AttachmentTotalSizeLimit: conf.VisitorAttachmentTotalSizeLimit,
AttachmentFileSizeLimit: conf.AttachmentFileSizeLimit,
AttachmentExpiryDuration: conf.AttachmentExpiryDuration,
AttachmentBandwidthLimit: conf.VisitorAttachmentDailyBandwidthLimit,
}
return limits
}
func (v *visitor) Info() (*visitorInfo, error) {
@ -321,9 +372,9 @@ func (v *visitor) Info() (*visitorInfo, error) {
limits := v.Limits()
stats := &visitorStats{
Messages: messages,
MessagesRemaining: zeroIfNegative(limits.MessagesLimit - messages),
MessagesRemaining: zeroIfNegative(limits.MessageLimit - messages),
Emails: emails,
EmailsRemaining: zeroIfNegative(limits.EmailsLimit - emails),
EmailsRemaining: zeroIfNegative(limits.EmailLimit - emails),
Reservations: reservations,
ReservationsRemaining: zeroIfNegative(limits.ReservationsLimit - reservations),
AttachmentTotalSize: attachmentsBytesUsed,
@ -343,23 +394,16 @@ func zeroIfNegative(value int64) int64 {
}
func replenishDurationToDailyLimit(duration time.Duration) int64 {
return int64(24 * time.Hour / duration)
return int64(oneDay / duration)
}
func dailyLimitToRate(limit int64) rate.Limit {
return rate.Limit(limit) * rate.Every(24*time.Hour)
return rate.Limit(limit) * rate.Every(oneDay)
}
func defaultVisitorLimits(conf *Config) *visitorLimits {
return &visitorLimits{
Basis: visitorLimitBasisIP,
MessagesLimit: replenishDurationToDailyLimit(conf.VisitorRequestLimitReplenish),
MessagesExpiryDuration: conf.CacheDuration,
EmailsLimit: replenishDurationToDailyLimit(conf.VisitorEmailLimitReplenish),
ReservationsLimit: visitorDefaultReservationsLimit,
AttachmentTotalSizeLimit: conf.VisitorAttachmentTotalSizeLimit,
AttachmentFileSizeLimit: conf.AttachmentFileSizeLimit,
AttachmentExpiryDuration: conf.AttachmentExpiryDuration,
AttachmentBandwidthLimit: conf.VisitorAttachmentDailyBandwidthLimit,
func visitorID(ip netip.Addr, u *user.User) string {
if u != nil && u.Tier != nil {
return fmt.Sprintf("user:%s", u.ID)
}
return fmt.Sprintf("ip:%s", ip.String())
}

View file

@ -709,10 +709,10 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
ID: tierID.String,
Code: tierCode.String,
Name: tierName.String,
MessagesLimit: messagesLimit.Int64,
MessagesExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
EmailsLimit: emailsLimit.Int64,
ReservationsLimit: reservationsLimit.Int64,
MessageLimit: messagesLimit.Int64,
MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
EmailLimit: emailsLimit.Int64,
ReservationLimit: reservationsLimit.Int64,
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
@ -845,7 +845,7 @@ func (a *Manager) ChangeTier(username, tier string) error {
t, err := a.Tier(tier)
if err != nil {
return err
} else if err := a.checkReservationsLimit(username, t.ReservationsLimit); err != nil {
} else if err := a.checkReservationsLimit(username, t.ReservationLimit); err != nil {
return err
}
if _, err := a.db.Exec(updateUserTierQuery, tier, username); err != nil {
@ -870,7 +870,7 @@ func (a *Manager) checkReservationsLimit(username string, reservationsLimit int6
if err != nil {
return err
}
if u.Tier != nil && reservationsLimit < u.Tier.ReservationsLimit {
if u.Tier != nil && reservationsLimit < u.Tier.ReservationLimit {
reservations, err := a.Reservations(username)
if err != nil {
return err
@ -999,7 +999,7 @@ func (a *Manager) CreateTier(tier *Tier) error {
if tier.ID == "" {
tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength)
}
if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, tier.StripePriceID); err != nil {
if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, tier.StripePriceID); err != nil {
return err
}
return nil
@ -1070,10 +1070,10 @@ func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
ID: id,
Code: code,
Name: name,
MessagesLimit: messagesLimit.Int64,
MessagesExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
EmailsLimit: emailsLimit.Int64,
ReservationsLimit: reservationsLimit.Int64,
MessageLimit: messagesLimit.Int64,
MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
EmailLimit: emailsLimit.Int64,
ReservationLimit: reservationsLimit.Int64,
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,

View file

@ -335,10 +335,10 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
Code: "pro",
Name: "ntfy Pro",
StripePriceID: "price123",
MessagesLimit: 5_000,
MessagesExpiryDuration: 3 * 24 * time.Hour,
EmailsLimit: 50,
ReservationsLimit: 5,
MessageLimit: 5_000,
MessageExpiryDuration: 3 * 24 * time.Hour,
EmailLimit: 50,
ReservationLimit: 5,
AttachmentFileSizeLimit: 52428800,
AttachmentTotalSizeLimit: 524288000,
AttachmentExpiryDuration: 24 * time.Hour,
@ -351,10 +351,10 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
require.Nil(t, err)
require.Equal(t, RoleUser, ben.Role)
require.Equal(t, "pro", ben.Tier.Code)
require.Equal(t, int64(5000), ben.Tier.MessagesLimit)
require.Equal(t, 3*24*time.Hour, ben.Tier.MessagesExpiryDuration)
require.Equal(t, int64(50), ben.Tier.EmailsLimit)
require.Equal(t, int64(5), ben.Tier.ReservationsLimit)
require.Equal(t, int64(5000), ben.Tier.MessageLimit)
require.Equal(t, 3*24*time.Hour, ben.Tier.MessageExpiryDuration)
require.Equal(t, int64(50), ben.Tier.EmailLimit)
require.Equal(t, int64(5), ben.Tier.ReservationLimit)
require.Equal(t, int64(52428800), ben.Tier.AttachmentFileSizeLimit)
require.Equal(t, int64(524288000), ben.Tier.AttachmentTotalSizeLimit)
require.Equal(t, 24*time.Hour, ben.Tier.AttachmentExpiryDuration)

View file

@ -62,10 +62,10 @@ type Tier struct {
ID string // Tier identifier (ti_...)
Code string // Code of the tier
Name string // Name of the tier
MessagesLimit int64 // Daily message limit
MessagesExpiryDuration time.Duration // Cache duration for messages
EmailsLimit int64 // Daily email limit
ReservationsLimit int64 // Number of topic reservations allowed by user
MessageLimit int64 // Daily message limit
MessageExpiryDuration time.Duration // Cache duration for messages
EmailLimit int64 // Daily email limit
ReservationLimit int64 // Number of topic reservations allowed by user
AttachmentFileSizeLimit int64 // Max file size per file (bytes)
AttachmentTotalSizeLimit int64 // Total file size for all files of this user (bytes)
AttachmentExpiryDuration time.Duration // Duration after which attachments will be deleted

View file

@ -27,8 +27,14 @@ type FixedLimiter struct {
// NewFixedLimiter creates a new Limiter
func NewFixedLimiter(limit int64) *FixedLimiter {
return NewFixedLimiterWithValue(limit, 0)
}
// NewFixedLimiterWithValue creates a new Limiter and sets the initial value
func NewFixedLimiterWithValue(limit, value int64) *FixedLimiter {
return &FixedLimiter{
limit: limit,
value: value,
}
}

View file

@ -17,7 +17,7 @@ var (
// NextOccurrenceUTC takes a time of day (e.g. 9:00am), and returns the next occurrence
// of that time from the current time (in UTC).
func NextOccurrenceUTC(timeOfDay, base time.Time) time.Time {
hour, minute, seconds := timeOfDay.Clock()
hour, minute, seconds := timeOfDay.UTC().Clock()
now := base.UTC()
next := time.Date(now.Year(), now.Month(), now.Day(), hour, minute, seconds, 0, time.UTC)
if next.Before(now) {

View file

@ -337,6 +337,17 @@ func Retry[T any](f func() (*T, error), after ...time.Duration) (t *T, err error
return nil, err
}
// MinMax returns value if it is between min and max, or either
// min or max if it is out of range
func MinMax[T int | int64](value, min, max T) T {
if value < min {
return min
} else if value > max {
return max
}
return value
}
// String turns a string into a pointer of a string
func String(v string) *string {
return &v