mirror of
https://github.com/binwiederhier/ntfy.git
synced 2024-11-22 03:13:33 +01:00
Rate limits make sense now!
This commit is contained in:
parent
a036814d98
commit
c874a641df
17 changed files with 365 additions and 205 deletions
|
@ -77,6 +77,7 @@ var flagsServe = append(
|
|||
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-request-limit-burst", Aliases: []string{"visitor_request_limit_burst"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_BURST"}, Value: server.DefaultVisitorRequestLimitBurst, Usage: "initial limit of requests per visitor"}),
|
||||
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-request-limit-replenish", Aliases: []string{"visitor_request_limit_replenish"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_REPLENISH"}, Value: server.DefaultVisitorRequestLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "visitor-request-limit-exempt-hosts", Aliases: []string{"visitor_request_limit_exempt_hosts"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_EXEMPT_HOSTS"}, Value: "", Usage: "hostnames and/or IP addresses of hosts that will be exempt from the visitor request limit"}),
|
||||
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-message-daily-limit", Aliases: []string{"visitor_message_daily_limit"}, EnvVars: []string{"NTFY_VISITOR_MESSAGE_DAILY_LIMIT"}, Value: server.DefaultVisitorMessageDailyLimit, Usage: "max messages per visitor per day, derived from request limit if unset"}),
|
||||
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-email-limit-burst", Aliases: []string{"visitor_email_limit_burst"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_BURST"}, Value: server.DefaultVisitorEmailLimitBurst, Usage: "initial limit of e-mails per visitor"}),
|
||||
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-email-limit-replenish", Aliases: []string{"visitor_email_limit_replenish"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_REPLENISH"}, Value: server.DefaultVisitorEmailLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}),
|
||||
altsrc.NewBoolFlag(&cli.BoolFlag{Name: "behind-proxy", Aliases: []string{"behind_proxy", "P"}, EnvVars: []string{"NTFY_BEHIND_PROXY"}, Value: false, Usage: "if set, use X-Forwarded-For header to determine visitor IP address (for rate limiting)"}),
|
||||
|
@ -150,6 +151,7 @@ func execServe(c *cli.Context) error {
|
|||
visitorRequestLimitBurst := c.Int("visitor-request-limit-burst")
|
||||
visitorRequestLimitReplenish := c.Duration("visitor-request-limit-replenish")
|
||||
visitorRequestLimitExemptHosts := util.SplitNoEmpty(c.String("visitor-request-limit-exempt-hosts"), ",")
|
||||
visitorMessageDailyLimit := c.Int("visitor-message-daily-limit")
|
||||
visitorEmailLimitBurst := c.Int("visitor-email-limit-burst")
|
||||
visitorEmailLimitReplenish := c.Duration("visitor-email-limit-replenish")
|
||||
behindProxy := c.Bool("behind-proxy")
|
||||
|
@ -289,6 +291,7 @@ func execServe(c *cli.Context) error {
|
|||
conf.VisitorRequestLimitBurst = visitorRequestLimitBurst
|
||||
conf.VisitorRequestLimitReplenish = visitorRequestLimitReplenish
|
||||
conf.VisitorRequestExemptIPAddrs = visitorRequestLimitExemptIPs
|
||||
conf.VisitorMessageDailyLimit = visitorMessageDailyLimit
|
||||
conf.VisitorEmailLimitBurst = visitorEmailLimitBurst
|
||||
conf.VisitorEmailLimitReplenish = visitorEmailLimitReplenish
|
||||
conf.BehindProxy = behindProxy
|
||||
|
|
|
@ -44,6 +44,7 @@ const (
|
|||
DefaultVisitorSubscriptionLimit = 30
|
||||
DefaultVisitorRequestLimitBurst = 60
|
||||
DefaultVisitorRequestLimitReplenish = 5 * time.Second
|
||||
DefaultVisitorMessageDailyLimit = 0
|
||||
DefaultVisitorEmailLimitBurst = 16
|
||||
DefaultVisitorEmailLimitReplenish = time.Hour
|
||||
DefaultVisitorAccountCreationLimitBurst = 3
|
||||
|
@ -105,6 +106,7 @@ type Config struct {
|
|||
VisitorRequestLimitBurst int
|
||||
VisitorRequestLimitReplenish time.Duration
|
||||
VisitorRequestExemptIPAddrs []netip.Prefix
|
||||
VisitorMessageDailyLimit int
|
||||
VisitorEmailLimitBurst int
|
||||
VisitorEmailLimitReplenish time.Duration
|
||||
VisitorAccountCreationLimitBurst int
|
||||
|
@ -171,6 +173,7 @@ func NewConfig() *Config {
|
|||
VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst,
|
||||
VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish,
|
||||
VisitorRequestExemptIPAddrs: make([]netip.Prefix, 0),
|
||||
VisitorMessageDailyLimit: DefaultVisitorMessageDailyLimit,
|
||||
VisitorEmailLimitBurst: DefaultVisitorEmailLimitBurst,
|
||||
VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish,
|
||||
VisitorAccountCreationLimitBurst: DefaultVisitorAccountCreationLimitBurst,
|
||||
|
|
|
@ -75,10 +75,10 @@ var (
|
|||
errHTTPTooManyRequestsLimitEmails = &errHTTP{42902, http.StatusTooManyRequests, "limit reached: too many emails, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
|
||||
errHTTPTooManyRequestsLimitSubscriptions = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
|
||||
errHTTPTooManyRequestsLimitTotalTopics = &errHTTP{42904, http.StatusTooManyRequests, "limit reached: the total number of topics on the server has been reached, please contact the admin", "https://ntfy.sh/docs/publish/#limitations"}
|
||||
errHTTPTooManyRequestsLimitAttachmentBandwidth = &errHTTP{42905, http.StatusTooManyRequests, "limit reached: daily bandwidth", "https://ntfy.sh/docs/publish/#limitations"}
|
||||
errHTTPTooManyRequestsLimitAttachmentBandwidth = &errHTTP{42905, http.StatusTooManyRequests, "limit reached: daily bandwidth reached", "https://ntfy.sh/docs/publish/#limitations"}
|
||||
errHTTPTooManyRequestsLimitAccountCreation = &errHTTP{42906, http.StatusTooManyRequests, "limit reached: too many accounts created", "https://ntfy.sh/docs/publish/#limitations"} // FIXME document limit
|
||||
errHTTPTooManyRequestsLimitReservations = &errHTTP{42907, http.StatusTooManyRequests, "limit reached: too many topic reservations for this user", ""}
|
||||
errHTTPTooManyRequestsLimitMessages = &errHTTP{42908, http.StatusTooManyRequests, "limit reached: too many messages", "https://ntfy.sh/docs/publish/#limitations"}
|
||||
errHTTPTooManyRequestsLimitMessages = &errHTTP{42908, http.StatusTooManyRequests, "limit reached: daily message quota reached", "https://ntfy.sh/docs/publish/#limitations"}
|
||||
errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""}
|
||||
errHTTPInternalErrorInvalidPath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid path", ""}
|
||||
errHTTPInternalErrorMissingBaseURL = &errHTTP{50003, http.StatusInternalServerError, "internal server error: base-url must be be configured for this feature", "https://ntfy.sh/docs/config/"}
|
||||
|
|
|
@ -38,10 +38,9 @@ import (
|
|||
TODO
|
||||
--
|
||||
|
||||
- HIGH Rate limiting: dailyLimitToRate is wrong? + TESTS
|
||||
- HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...)
|
||||
- HIGH Rate limiting: Delete visitor when tier is changed to refresh rate limiters
|
||||
- HIGH Rate limiting: When ResetStats() is run, reset messagesLimiter (and others)?
|
||||
- MEDIUM Rate limiting: Test daily message quota read from database initially
|
||||
- MEDIUM: Races with v.user (see publishSyncEventAsync test)
|
||||
- MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
|
||||
- MEDIUM: Reservation (UI): Ask for confirmation when removing reservation (deadcade)
|
||||
|
@ -57,7 +56,6 @@ Make sure account endpoints make sense for admins
|
|||
|
||||
Tests:
|
||||
- Payment endpoints (make mocks)
|
||||
- test that the visitor is based on the IP address when a user has no tier
|
||||
*/
|
||||
|
||||
// Server is the main server, providing the UI and API for ntfy
|
||||
|
@ -308,7 +306,7 @@ func (s *Server) Stop() {
|
|||
}
|
||||
|
||||
func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
|
||||
v, err := s.visitor(r) // Note: Always returns v, even when error is returned
|
||||
v, err := s.maybeAuthenticate(r) // Note: Always returns v, even when error is returned
|
||||
if err == nil {
|
||||
log.Debug("%s Dispatching request", logHTTPPrefix(v, r))
|
||||
if log.IsTrace() {
|
||||
|
@ -563,7 +561,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
|
|||
if v.user != nil {
|
||||
m.User = v.user.ID
|
||||
}
|
||||
m.Expires = time.Now().Add(v.Limits().MessagesExpiryDuration).Unix()
|
||||
m.Expires = time.Now().Add(v.Limits().MessageExpiryDuration).Unix()
|
||||
if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -601,7 +599,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
|
|||
}
|
||||
v.IncrementMessages()
|
||||
if s.userManager != nil && v.user != nil {
|
||||
s.userManager.EnqueueStats(v.user)
|
||||
s.userManager.EnqueueStats(v.user) // FIXME this makes no sense for tier-less users
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.messages++
|
||||
|
@ -1382,8 +1380,10 @@ func (s *Server) runStatsResetter() {
|
|||
log.Debug("Stats resetter: Waiting until %v to reset visitor stats", runAt)
|
||||
select {
|
||||
case <-timer.C:
|
||||
log.Debug("Stats resetter: Running")
|
||||
s.resetStats()
|
||||
case <-s.closeChan:
|
||||
log.Debug("Stats resetter: Stopping timer")
|
||||
timer.Stop()
|
||||
return
|
||||
}
|
||||
|
@ -1440,17 +1440,15 @@ func (s *Server) sendDelayedMessages() error {
|
|||
return err
|
||||
}
|
||||
for _, m := range messages {
|
||||
var v *visitor
|
||||
var u *user.User
|
||||
if s.userManager != nil && m.User != "" {
|
||||
u, err := s.userManager.User(m.User)
|
||||
u, err = s.userManager.User(m.User)
|
||||
if err != nil {
|
||||
log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error())
|
||||
log.Warn("Error sending delayed message %s: %s", m.ID, err.Error())
|
||||
continue
|
||||
}
|
||||
v = s.visitorFromUser(u, m.Sender)
|
||||
} else {
|
||||
v = s.visitorFromIP(m.Sender)
|
||||
}
|
||||
v := s.visitor(m.Sender, u)
|
||||
if err := s.sendDelayedMessage(v, m); err != nil {
|
||||
log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error())
|
||||
}
|
||||
|
@ -1588,20 +1586,16 @@ func (s *Server) autorizeTopic(next handleFunc, perm user.Permission) handleFunc
|
|||
}
|
||||
}
|
||||
|
||||
// visitor creates or retrieves a rate.Limiter for the given visitor.
|
||||
// maybeAuthenticate creates or retrieves a rate.Limiter for the given visitor.
|
||||
// Note that this function will always return a visitor, even if an error occurs.
|
||||
func (s *Server) visitor(r *http.Request) (v *visitor, err error) {
|
||||
func (s *Server) maybeAuthenticate(r *http.Request) (v *visitor, err error) {
|
||||
ip := extractIPAddress(r, s.config.BehindProxy)
|
||||
var u *user.User // may stay nil if no auth header!
|
||||
if u, err = s.authenticate(r); err != nil {
|
||||
log.Debug("authentication failed: %s", err.Error())
|
||||
err = errHTTPUnauthorized // Always return visitor, even when error occurs!
|
||||
}
|
||||
if u != nil {
|
||||
v = s.visitorFromUser(u, ip)
|
||||
} else {
|
||||
v = s.visitorFromIP(ip)
|
||||
}
|
||||
v = s.visitor(ip, u)
|
||||
v.SetUser(u) // Update visitor user with latest from database!
|
||||
return v, err // Always return visitor, even when error occurs!
|
||||
}
|
||||
|
@ -1645,26 +1639,19 @@ func (s *Server) authenticateBearerAuth(value string) (user *user.User, err erro
|
|||
return s.userManager.AuthenticateToken(token)
|
||||
}
|
||||
|
||||
func (s *Server) visitorFromID(visitorID string, ip netip.Addr, user *user.User) *visitor {
|
||||
func (s *Server) visitor(ip netip.Addr, user *user.User) *visitor {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
v, exists := s.visitors[visitorID]
|
||||
id := visitorID(ip, user)
|
||||
v, exists := s.visitors[id]
|
||||
if !exists {
|
||||
s.visitors[visitorID] = newVisitor(s.config, s.messageCache, s.userManager, ip, user)
|
||||
return s.visitors[visitorID]
|
||||
s.visitors[id] = newVisitor(s.config, s.messageCache, s.userManager, ip, user)
|
||||
return s.visitors[id]
|
||||
}
|
||||
v.Keepalive()
|
||||
return v
|
||||
}
|
||||
|
||||
func (s *Server) visitorFromIP(ip netip.Addr) *visitor {
|
||||
return s.visitorFromID(fmt.Sprintf("ip:%s", ip.String()), ip, nil)
|
||||
}
|
||||
|
||||
func (s *Server) visitorFromUser(user *user.User, ip netip.Addr) *visitor {
|
||||
return s.visitorFromID(fmt.Sprintf("user:%s", user.ID), ip, user)
|
||||
}
|
||||
|
||||
func (s *Server) writeJSON(w http.ResponseWriter, v any) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
||||
|
|
|
@ -200,6 +200,12 @@
|
|||
# visitor-request-limit-replenish: "5s"
|
||||
# visitor-request-limit-exempt-hosts: ""
|
||||
|
||||
# Rate limiting: Hard daily limit of messages per visitor and day. The limit is reset
|
||||
# every day at midnight UTC. If the limit is not set (or set to zero), the request
|
||||
# limit (see above) governs the upper limit.
|
||||
#
|
||||
# visitor-message-daily-limit: 0
|
||||
|
||||
# Rate limiting: Allowed emails per visitor:
|
||||
# - visitor-email-limit-burst is the initial bucket of emails each visitor has
|
||||
# - visitor-email-limit-replenish is the rate at which the bucket is refilled
|
||||
|
|
|
@ -23,6 +23,9 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *
|
|||
} else if v.user != nil {
|
||||
return errHTTPUnauthorized // Cannot create account from user context
|
||||
}
|
||||
if err := v.AccountCreationAllowed(); err != nil {
|
||||
return errHTTPTooManyRequestsLimitAccountCreation
|
||||
}
|
||||
}
|
||||
newAccount, err := readJSONWithLimit[apiAccountCreateRequest](r.Body, jsonBodyBytesLimit)
|
||||
if err != nil {
|
||||
|
@ -31,9 +34,6 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *
|
|||
if existingUser, _ := s.userManager.User(newAccount.Username); existingUser != nil {
|
||||
return errHTTPConflictUserExists
|
||||
}
|
||||
if err := v.AccountCreationAllowed(); err != nil {
|
||||
return errHTTPTooManyRequestsLimitAccountCreation
|
||||
}
|
||||
if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser); err != nil { // TODO this should return a User
|
||||
return err
|
||||
}
|
||||
|
@ -49,9 +49,9 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis
|
|||
response := &apiAccountResponse{
|
||||
Limits: &apiAccountLimits{
|
||||
Basis: string(limits.Basis),
|
||||
Messages: limits.MessagesLimit,
|
||||
MessagesExpiryDuration: int64(limits.MessagesExpiryDuration.Seconds()),
|
||||
Emails: limits.EmailsLimit,
|
||||
Messages: limits.MessageLimit,
|
||||
MessagesExpiryDuration: int64(limits.MessageExpiryDuration.Seconds()),
|
||||
Emails: limits.EmailLimit,
|
||||
Reservations: limits.ReservationsLimit,
|
||||
AttachmentTotalSize: limits.AttachmentTotalSizeLimit,
|
||||
AttachmentFileSize: limits.AttachmentFileSizeLimit,
|
||||
|
@ -344,7 +344,7 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
|
|||
reservations, err := s.userManager.ReservationsCount(v.user.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if reservations >= v.user.Tier.ReservationsLimit {
|
||||
} else if reservations >= v.user.Tier.ReservationLimit {
|
||||
return errHTTPTooManyRequestsLimitReservations
|
||||
}
|
||||
}
|
||||
|
|
|
@ -410,10 +410,10 @@ func TestAccount_Reservation_AddRemoveUserWithTierSuccess(t *testing.T) {
|
|||
// Create a tier
|
||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||
Code: "pro",
|
||||
MessagesLimit: 123,
|
||||
MessagesExpiryDuration: 86400 * time.Second,
|
||||
EmailsLimit: 32,
|
||||
ReservationsLimit: 2,
|
||||
MessageLimit: 123,
|
||||
MessageExpiryDuration: 86400 * time.Second,
|
||||
EmailLimit: 32,
|
||||
ReservationLimit: 2,
|
||||
AttachmentFileSizeLimit: 1231231,
|
||||
AttachmentTotalSizeLimit: 123123,
|
||||
AttachmentExpiryDuration: 10800 * time.Second,
|
||||
|
@ -492,8 +492,8 @@ func TestAccount_Reservation_PublishByAnonymousFails(t *testing.T) {
|
|||
|
||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||
Code: "pro",
|
||||
MessagesLimit: 20,
|
||||
ReservationsLimit: 2,
|
||||
MessageLimit: 20,
|
||||
ReservationLimit: 2,
|
||||
}))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
|
||||
|
@ -526,8 +526,8 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) {
|
|||
|
||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||
Code: "pro",
|
||||
MessagesLimit: 20,
|
||||
ReservationsLimit: 2,
|
||||
MessageLimit: 20,
|
||||
ReservationLimit: 2,
|
||||
}))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
|
||||
|
@ -591,10 +591,10 @@ func TestAccount_Tier_Create(t *testing.T) {
|
|||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||
Code: "pro",
|
||||
Name: "Pro",
|
||||
MessagesLimit: 123,
|
||||
MessagesExpiryDuration: 86400 * time.Second,
|
||||
EmailsLimit: 32,
|
||||
ReservationsLimit: 2,
|
||||
MessageLimit: 123,
|
||||
MessageExpiryDuration: 86400 * time.Second,
|
||||
EmailLimit: 32,
|
||||
ReservationLimit: 2,
|
||||
AttachmentFileSizeLimit: 1231231,
|
||||
AttachmentTotalSizeLimit: 123123,
|
||||
AttachmentExpiryDuration: 10800 * time.Second,
|
||||
|
@ -616,10 +616,10 @@ func TestAccount_Tier_Create(t *testing.T) {
|
|||
require.True(t, strings.HasPrefix(ti.ID, "ti_"))
|
||||
require.Equal(t, "pro", ti.Code)
|
||||
require.Equal(t, "Pro", ti.Name)
|
||||
require.Equal(t, int64(123), ti.MessagesLimit)
|
||||
require.Equal(t, 86400*time.Second, ti.MessagesExpiryDuration)
|
||||
require.Equal(t, int64(32), ti.EmailsLimit)
|
||||
require.Equal(t, int64(2), ti.ReservationsLimit)
|
||||
require.Equal(t, int64(123), ti.MessageLimit)
|
||||
require.Equal(t, 86400*time.Second, ti.MessageExpiryDuration)
|
||||
require.Equal(t, int64(32), ti.EmailLimit)
|
||||
require.Equal(t, int64(2), ti.ReservationLimit)
|
||||
require.Equal(t, int64(1231231), ti.AttachmentFileSizeLimit)
|
||||
require.Equal(t, int64(123123), ti.AttachmentTotalSizeLimit)
|
||||
require.Equal(t, 10800*time.Second, ti.AttachmentExpiryDuration)
|
||||
|
|
|
@ -60,15 +60,15 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
freeTier := defaultVisitorLimits(s.config)
|
||||
freeTier := configBasedVisitorLimits(s.config)
|
||||
response := []*apiAccountBillingTier{
|
||||
{
|
||||
// This is a bit of a hack: This is the "Free" tier. It has no tier code, name or price.
|
||||
Limits: &apiAccountLimits{
|
||||
Basis: string(visitorLimitBasisIP),
|
||||
Messages: freeTier.MessagesLimit,
|
||||
MessagesExpiryDuration: int64(freeTier.MessagesExpiryDuration.Seconds()),
|
||||
Emails: freeTier.EmailsLimit,
|
||||
Messages: freeTier.MessageLimit,
|
||||
MessagesExpiryDuration: int64(freeTier.MessageExpiryDuration.Seconds()),
|
||||
Emails: freeTier.EmailLimit,
|
||||
Reservations: freeTier.ReservationsLimit,
|
||||
AttachmentTotalSize: freeTier.AttachmentTotalSizeLimit,
|
||||
AttachmentFileSize: freeTier.AttachmentFileSizeLimit,
|
||||
|
@ -91,10 +91,10 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
|
|||
Price: priceStr,
|
||||
Limits: &apiAccountLimits{
|
||||
Basis: string(visitorLimitBasisTier),
|
||||
Messages: tier.MessagesLimit,
|
||||
MessagesExpiryDuration: int64(tier.MessagesExpiryDuration.Seconds()),
|
||||
Emails: tier.EmailsLimit,
|
||||
Reservations: tier.ReservationsLimit,
|
||||
Messages: tier.MessageLimit,
|
||||
MessagesExpiryDuration: int64(tier.MessageExpiryDuration.Seconds()),
|
||||
Emails: tier.EmailLimit,
|
||||
Reservations: tier.ReservationLimit,
|
||||
AttachmentTotalSize: tier.AttachmentTotalSizeLimit,
|
||||
AttachmentFileSize: tier.AttachmentFileSizeLimit,
|
||||
AttachmentExpiryDuration: int64(tier.AttachmentExpiryDuration.Seconds()),
|
||||
|
@ -336,7 +336,7 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe
|
|||
if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, tier, ev.Customer, subscriptionID, ev.Status, ev.CurrentPeriodEnd, ev.CancelAt); err != nil {
|
||||
return err
|
||||
}
|
||||
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
|
||||
s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u))
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -355,14 +355,14 @@ func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMe
|
|||
if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, nil, ev.Customer, "", "", 0, 0); err != nil {
|
||||
return err
|
||||
}
|
||||
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
|
||||
s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) updateSubscriptionAndTier(logPrefix string, u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error {
|
||||
reservationsLimit := visitorDefaultReservationsLimit
|
||||
if tier != nil {
|
||||
reservationsLimit = tier.ReservationsLimit
|
||||
reservationsLimit = tier.ReservationLimit
|
||||
}
|
||||
if err := s.maybeRemoveMessagesAndExcessReservations(logPrefix, u, reservationsLimit); err != nil {
|
||||
return err
|
||||
|
|
|
@ -5,11 +5,14 @@ import (
|
|||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stripe/stripe-go/v74"
|
||||
"golang.org/x/time/rate"
|
||||
"heckel.io/ntfy/user"
|
||||
"heckel.io/ntfy/util"
|
||||
"io"
|
||||
"net/netip"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
@ -48,10 +51,10 @@ func TestPayments_Tiers(t *testing.T) {
|
|||
ID: "ti_123",
|
||||
Code: "pro",
|
||||
Name: "Pro",
|
||||
MessagesLimit: 1000,
|
||||
MessagesExpiryDuration: time.Hour,
|
||||
EmailsLimit: 123,
|
||||
ReservationsLimit: 777,
|
||||
MessageLimit: 1000,
|
||||
MessageExpiryDuration: time.Hour,
|
||||
EmailLimit: 123,
|
||||
ReservationLimit: 777,
|
||||
AttachmentFileSizeLimit: 999,
|
||||
AttachmentTotalSizeLimit: 888,
|
||||
AttachmentExpiryDuration: time.Minute,
|
||||
|
@ -61,10 +64,10 @@ func TestPayments_Tiers(t *testing.T) {
|
|||
ID: "ti_444",
|
||||
Code: "business",
|
||||
Name: "Business",
|
||||
MessagesLimit: 2000,
|
||||
MessagesExpiryDuration: 10 * time.Hour,
|
||||
EmailsLimit: 123123,
|
||||
ReservationsLimit: 777333,
|
||||
MessageLimit: 2000,
|
||||
MessageExpiryDuration: 10 * time.Hour,
|
||||
EmailLimit: 123123,
|
||||
ReservationLimit: 777333,
|
||||
AttachmentFileSizeLimit: 999111,
|
||||
AttachmentTotalSizeLimit: 888111,
|
||||
AttachmentExpiryDuration: time.Hour,
|
||||
|
@ -238,9 +241,14 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
|
|||
require.Equal(t, 401, rr.Code)
|
||||
}
|
||||
|
||||
func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *testing.T) {
|
||||
// This tests a successful checkout flow (not a paying customer -> paying customer),
|
||||
// and also tests that during the upgrade we are RESETTING THE RATE LIMITS of the existing user.
|
||||
func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *testing.T) {
|
||||
// This test is too overloaded, but it's also a great end-to-end a test.
|
||||
//
|
||||
// It tests:
|
||||
// - A successful checkout flow (not a paying customer -> paying customer)
|
||||
// - Tier-changes reset the rate limits for the user
|
||||
// - The request limits for tier-less user and a tier-user
|
||||
// - The message limits for a tier-user
|
||||
|
||||
stripeMock := &testStripeAPI{}
|
||||
defer stripeMock.AssertExpectations(t)
|
||||
|
@ -248,8 +256,15 @@ func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *test
|
|||
c := newTestConfigWithAuthFile(t)
|
||||
c.StripeSecretKey = "secret key"
|
||||
c.StripeWebhookKey = "webhook key"
|
||||
c.VisitorRequestLimitBurst = 10
|
||||
c.VisitorRequestLimitBurst = 5
|
||||
c.VisitorRequestLimitReplenish = time.Hour
|
||||
c.CacheStartupQueries = `
|
||||
pragma journal_mode = WAL;
|
||||
pragma synchronous = normal;
|
||||
pragma temp_store = memory;
|
||||
`
|
||||
c.CacheBatchSize = 500
|
||||
c.CacheBatchTimeout = time.Second
|
||||
s := newTestServer(t, c)
|
||||
s.stripe = stripeMock
|
||||
|
||||
|
@ -258,9 +273,9 @@ func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *test
|
|||
ID: "ti_123",
|
||||
Code: "starter",
|
||||
StripePriceID: "price_1234",
|
||||
ReservationsLimit: 1,
|
||||
MessagesLimit: 100,
|
||||
MessagesExpiryDuration: time.Hour,
|
||||
ReservationLimit: 1,
|
||||
MessageLimit: 220, // 220 * 5% = 11 requests before rate limiting kicks in
|
||||
MessageExpiryDuration: time.Hour,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) // No tier
|
||||
u, err := s.userManager.User("phil")
|
||||
|
@ -298,7 +313,7 @@ func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *test
|
|||
Return(&stripe.Customer{}, nil)
|
||||
|
||||
// Send messages until rate limit of free tier is hit
|
||||
for i := 0; i < 10; i++ {
|
||||
for i := 0; i < 5; i++ {
|
||||
rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
|
@ -323,10 +338,9 @@ func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *test
|
|||
require.Equal(t, int64(123456789), u.Billing.StripeSubscriptionPaidUntil.Unix())
|
||||
require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
|
||||
|
||||
// FIXME FIXME This test is broken, because the rate limit logic is unclear!
|
||||
|
||||
// Now for the fun part: Verify that new rate limits are immediately applied
|
||||
for i := 0; i < 100; i++ {
|
||||
// This only tests the request limiter, which kicks in before the message limiter.
|
||||
for i := 0; i < 11; i++ {
|
||||
rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
|
@ -336,6 +350,37 @@ func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *test
|
|||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 429, rr.Code)
|
||||
|
||||
// Now let's test the message limiter by faking a ridiculously generous rate limiter
|
||||
v := s.visitor(netip.MustParseAddr("9.9.9.9"), u)
|
||||
v.requestLimiter = rate.NewLimiter(rate.Every(time.Millisecond), 1000000)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 209; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
rr = request(t, s, "PUT", "/mytopic", "some message", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 429, rr.Code)
|
||||
|
||||
// And now let's cross-check that the stats are correct too
|
||||
rr = request(t, s, "GET", "/v1/account", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
account, _ := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(rr.Body))
|
||||
require.Equal(t, int64(220), account.Limits.Messages)
|
||||
require.Equal(t, int64(220), account.Stats.Messages)
|
||||
require.Equal(t, int64(0), account.Stats.MessagesRemaining)
|
||||
}
|
||||
|
||||
func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) {
|
||||
|
@ -363,9 +408,9 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
|
|||
ID: "ti_1",
|
||||
Code: "starter",
|
||||
StripePriceID: "price_1234", // !
|
||||
ReservationsLimit: 1, // !
|
||||
MessagesLimit: 100,
|
||||
MessagesExpiryDuration: time.Hour,
|
||||
ReservationLimit: 1, // !
|
||||
MessageLimit: 100,
|
||||
MessageExpiryDuration: time.Hour,
|
||||
AttachmentExpiryDuration: time.Hour,
|
||||
AttachmentFileSizeLimit: 1000000,
|
||||
AttachmentTotalSizeLimit: 1000000,
|
||||
|
@ -375,9 +420,9 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
|
|||
ID: "ti_2",
|
||||
Code: "pro",
|
||||
StripePriceID: "price_1111", // !
|
||||
ReservationsLimit: 3, // !
|
||||
MessagesLimit: 200,
|
||||
MessagesExpiryDuration: time.Hour,
|
||||
ReservationLimit: 3, // !
|
||||
MessageLimit: 200,
|
||||
MessageExpiryDuration: time.Hour,
|
||||
AttachmentExpiryDuration: time.Hour,
|
||||
AttachmentFileSizeLimit: 1000000,
|
||||
AttachmentTotalSizeLimit: 1000000,
|
||||
|
|
|
@ -8,7 +8,6 @@ import (
|
|||
"fmt"
|
||||
"heckel.io/ntfy/user"
|
||||
"io"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
@ -22,9 +21,14 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/log"
|
||||
"heckel.io/ntfy/util"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// log.SetLevel(log.DebugLevel)
|
||||
}
|
||||
|
||||
func TestServer_PublishAndPoll(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfig(t))
|
||||
|
||||
|
@ -742,16 +746,31 @@ func TestServer_Auth_ViaQuery(t *testing.T) {
|
|||
require.Equal(t, 401, response.Code)
|
||||
}
|
||||
|
||||
func TestServer_StatsResetter(t *testing.T) {
|
||||
func TestServer_StatsResetter_User_Without_Tier(t *testing.T) {
|
||||
// This tests the stats resetter for
|
||||
// - an anonymous user
|
||||
// - a user without a tier (treated like the same as the anonymous user)
|
||||
// - a user with a tier
|
||||
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.AuthDefault = user.PermissionDenyAll
|
||||
c.VisitorStatsResetTime = time.Now().Add(2 * time.Second)
|
||||
s := newTestServer(t, c)
|
||||
go s.runStatsResetter()
|
||||
|
||||
// Create user with tier (tieruser) and user without tier (phil)
|
||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||
Code: "test",
|
||||
MessageLimit: 5,
|
||||
MessageExpiryDuration: -5 * time.Second, // Second, what a hack!
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
require.Nil(t, s.userManager.AllowAccess("phil", "mytopic", user.PermissionReadWrite))
|
||||
require.Nil(t, s.userManager.AddUser("tieruser", "tieruser", user.RoleUser))
|
||||
require.Nil(t, s.userManager.ChangeTier("tieruser", "test"))
|
||||
|
||||
// Send an anonymous message
|
||||
response := request(t, s, "PUT", "/mytopic", "test", nil)
|
||||
|
||||
// Send messages from user without tier (phil)
|
||||
for i := 0; i < 5; i++ {
|
||||
response := request(t, s, "PUT", "/mytopic", "test", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
|
@ -759,30 +778,66 @@ func TestServer_StatsResetter(t *testing.T) {
|
|||
require.Equal(t, 200, response.Code)
|
||||
}
|
||||
|
||||
response := request(t, s, "GET", "/v1/account", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
// Send messages from user with tier
|
||||
for i := 0; i < 2; i++ {
|
||||
response := request(t, s, "PUT", "/mytopic", "test", map[string]string{
|
||||
"Authorization": util.BasicAuth("tieruser", "tieruser"),
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
}
|
||||
|
||||
// User stats show 10 messages
|
||||
// User stats show 6 messages (for user without tier)
|
||||
response = request(t, s, "GET", "/v1/account", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
account, err := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(5), account.Stats.Messages)
|
||||
require.Equal(t, int64(6), account.Stats.Messages)
|
||||
|
||||
// User stats show 6 messages (for anonymous visitor)
|
||||
response = request(t, s, "GET", "/v1/account", "", nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(6), account.Stats.Messages)
|
||||
|
||||
// User stats show 2 messages (for user with tier)
|
||||
response = request(t, s, "GET", "/v1/account", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("tieruser", "tieruser"),
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(2), account.Stats.Messages)
|
||||
|
||||
// Wait for stats resetter to run
|
||||
time.Sleep(2200 * time.Millisecond)
|
||||
|
||||
// User stats show 0 messages now!
|
||||
response = request(t, s, "GET", "/v1/account", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(0), account.Stats.Messages)
|
||||
|
||||
// Since this is a user without a tier, the anonymous user should have the same stats
|
||||
response = request(t, s, "GET", "/v1/account", "", nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(0), account.Stats.Messages)
|
||||
|
||||
// User stats show 0 messages (for user with tier)
|
||||
response = request(t, s, "GET", "/v1/account", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("tieruser", "tieruser"),
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(0), account.Stats.Messages)
|
||||
}
|
||||
|
||||
type testMailer struct {
|
||||
|
@ -1134,8 +1189,8 @@ func TestServer_PublishWithTierBasedMessageLimitAndExpiry(t *testing.T) {
|
|||
// Create tier with certain limits
|
||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||
Code: "test",
|
||||
MessagesLimit: 5,
|
||||
MessagesExpiryDuration: -5 * time.Second, // Second, what a hack!
|
||||
MessageLimit: 5,
|
||||
MessageExpiryDuration: -5 * time.Second, // Second, what a hack!
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "test"))
|
||||
|
@ -1363,8 +1418,8 @@ func TestServer_PublishAttachmentWithTierBasedExpiry(t *testing.T) {
|
|||
sevenDays := time.Duration(604800) * time.Second
|
||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||
Code: "test",
|
||||
MessagesLimit: 10,
|
||||
MessagesExpiryDuration: sevenDays,
|
||||
MessageLimit: 10,
|
||||
MessageExpiryDuration: sevenDays,
|
||||
AttachmentFileSizeLimit: 50_000,
|
||||
AttachmentTotalSizeLimit: 200_000,
|
||||
AttachmentExpiryDuration: sevenDays, // 7 days
|
||||
|
@ -1407,8 +1462,8 @@ func TestServer_PublishAttachmentWithTierBasedBandwidthLimit(t *testing.T) {
|
|||
// Create tier with certain limits
|
||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||
Code: "test",
|
||||
MessagesLimit: 10,
|
||||
MessagesExpiryDuration: time.Hour,
|
||||
MessageLimit: 10,
|
||||
MessageExpiryDuration: time.Hour,
|
||||
AttachmentFileSizeLimit: 50_000,
|
||||
AttachmentTotalSizeLimit: 200_000,
|
||||
AttachmentExpiryDuration: time.Hour,
|
||||
|
@ -1450,7 +1505,7 @@ func TestServer_PublishAttachmentWithTierBasedLimits(t *testing.T) {
|
|||
// Create tier with certain limits
|
||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||
Code: "test",
|
||||
MessagesLimit: 100,
|
||||
MessageLimit: 100,
|
||||
AttachmentFileSizeLimit: 50_000,
|
||||
AttachmentTotalSizeLimit: 200_000,
|
||||
AttachmentExpiryDuration: 30 * time.Second,
|
||||
|
@ -1574,7 +1629,7 @@ func TestServer_Visitor_XForwardedFor_None(t *testing.T) {
|
|||
r, _ := http.NewRequest("GET", "/bla", nil)
|
||||
r.RemoteAddr = "8.9.10.11"
|
||||
r.Header.Set("X-Forwarded-For", " ") // Spaces, not empty!
|
||||
v, err := s.visitor(r)
|
||||
v, err := s.maybeAuthenticate(r)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "8.9.10.11", v.ip.String())
|
||||
}
|
||||
|
@ -1586,7 +1641,7 @@ func TestServer_Visitor_XForwardedFor_Single(t *testing.T) {
|
|||
r, _ := http.NewRequest("GET", "/bla", nil)
|
||||
r.RemoteAddr = "8.9.10.11"
|
||||
r.Header.Set("X-Forwarded-For", "1.1.1.1")
|
||||
v, err := s.visitor(r)
|
||||
v, err := s.maybeAuthenticate(r)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "1.1.1.1", v.ip.String())
|
||||
}
|
||||
|
@ -1598,7 +1653,7 @@ func TestServer_Visitor_XForwardedFor_Multiple(t *testing.T) {
|
|||
r, _ := http.NewRequest("GET", "/bla", nil)
|
||||
r.RemoteAddr = "8.9.10.11"
|
||||
r.Header.Set("X-Forwarded-For", "1.2.3.4 , 2.4.4.2,234.5.2.1 ")
|
||||
v, err := s.visitor(r)
|
||||
v, err := s.maybeAuthenticate(r)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "234.5.2.1", v.ip.String())
|
||||
}
|
||||
|
@ -1611,7 +1666,7 @@ func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) {
|
|||
s := newTestServer(t, c)
|
||||
|
||||
// Add lots of messages
|
||||
log.Printf("Adding %d messages", count)
|
||||
log.Info("Adding %d messages", count)
|
||||
start := time.Now()
|
||||
messages := make([]*message, 0)
|
||||
for i := 0; i < count; i++ {
|
||||
|
@ -1621,31 +1676,31 @@ func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) {
|
|||
messages = append(messages, newDefaultMessage(topicID, "some message"))
|
||||
}
|
||||
require.Nil(t, s.messageCache.addMessages(messages))
|
||||
log.Printf("Done: Adding %d messages; took %s", count, time.Since(start).Round(time.Millisecond))
|
||||
log.Info("Done: Adding %d messages; took %s", count, time.Since(start).Round(time.Millisecond))
|
||||
|
||||
// Update stats
|
||||
statsChan := make(chan bool)
|
||||
go func() {
|
||||
log.Printf("Updating stats")
|
||||
log.Info("Updating stats")
|
||||
start := time.Now()
|
||||
s.execManager()
|
||||
log.Printf("Done: Updating stats; took %s", time.Since(start).Round(time.Millisecond))
|
||||
log.Info("Done: Updating stats; took %s", time.Since(start).Round(time.Millisecond))
|
||||
statsChan <- true
|
||||
}()
|
||||
time.Sleep(50 * time.Millisecond) // Make sure it starts first
|
||||
|
||||
// Publish message (during stats update)
|
||||
log.Printf("Publishing message")
|
||||
log.Info("Publishing message")
|
||||
start = time.Now()
|
||||
response := request(t, s, "PUT", "/mytopic", "some body", nil)
|
||||
m := toMessage(t, response.Body.String())
|
||||
assert.Equal(t, "some body", m.Message)
|
||||
assert.True(t, time.Since(start) < 100*time.Millisecond)
|
||||
log.Printf("Done: Publishing message; took %s", time.Since(start).Round(time.Millisecond))
|
||||
log.Info("Done: Publishing message; took %s", time.Since(start).Round(time.Millisecond))
|
||||
|
||||
// Wait for all goroutines
|
||||
<-statsChan
|
||||
log.Printf("Done: Waiting for all locks")
|
||||
log.Info("Done: Waiting for all locks")
|
||||
}
|
||||
|
||||
func newTestConfig(t *testing.T) *Config {
|
||||
|
|
|
@ -14,16 +14,39 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
// oneDay is an approximation of a day as a time.Duration
|
||||
oneDay = 24 * time.Hour
|
||||
|
||||
// visitorExpungeAfter defines how long a visitor is active before it is removed from memory. This number
|
||||
// has to be very high to prevent e-mail abuse, but it doesn't really affect the other limits anyway, since
|
||||
// they are replenished faster (typically).
|
||||
visitorExpungeAfter = 24 * time.Hour
|
||||
visitorExpungeAfter = oneDay
|
||||
|
||||
// visitorDefaultReservationsLimit is the amount of topic names a user without a tier is allowed to reserve.
|
||||
// This number is zero, and changing it may have unintended consequences in the web app, or otherwise
|
||||
visitorDefaultReservationsLimit = int64(0)
|
||||
)
|
||||
|
||||
// Constants used to convert a tier-user's MessageLimit (see user.Tier) into adequate request limiter
|
||||
// values (token bucket).
|
||||
//
|
||||
// Example: Assuming a user.Tier's MessageLimit is 10,000:
|
||||
// - the allowed burst is 500 (= 10,000 * 5%), which is < 1000 (the max)
|
||||
// - the replenish rate is 2 * 10,000 / 24 hours
|
||||
const (
|
||||
visitorMessageToRequestLimitBurstRate = 0.05
|
||||
visitorMessageToRequestLimitBurstMax = 1000
|
||||
visitorMessageToRequestLimitReplenishFactor = 2
|
||||
)
|
||||
|
||||
// Constants used to convert a tier-user's EmailLimit (see user.Tier) into adequate email limiter
|
||||
// values (token bucket). Example: Assuming a user.Tier's EmailLimit is 200, the allowed burst is
|
||||
// 40 (= 200 * 20%), which is <150 (the max).
|
||||
const (
|
||||
visitorEmailLimitBurstRate = 0.2
|
||||
visitorEmailLimitBurstMax = 150
|
||||
)
|
||||
|
||||
var (
|
||||
errVisitorLimitReached = errors.New("limit reached")
|
||||
)
|
||||
|
@ -55,9 +78,13 @@ type visitorInfo struct {
|
|||
|
||||
type visitorLimits struct {
|
||||
Basis visitorLimitBasis
|
||||
MessagesLimit int64
|
||||
MessagesExpiryDuration time.Duration
|
||||
EmailsLimit int64
|
||||
RequestLimitBurst int
|
||||
RequestLimitReplenish rate.Limit
|
||||
MessageLimit int64
|
||||
MessageExpiryDuration time.Duration
|
||||
EmailLimit int64
|
||||
EmailLimitBurst int
|
||||
EmailLimitReplenish rate.Limit
|
||||
ReservationsLimit int64
|
||||
AttachmentTotalSizeLimit int64
|
||||
AttachmentFileSizeLimit int64
|
||||
|
@ -173,7 +200,7 @@ func (v *visitor) SubscriptionAllowed() error {
|
|||
}
|
||||
|
||||
func (v *visitor) AccountCreationAllowed() error {
|
||||
if v.accountLimiter != nil && !v.accountLimiter.Allow() {
|
||||
if v.accountLimiter == nil || (v.accountLimiter != nil && !v.accountLimiter.Allow()) {
|
||||
return errVisitorLimitReached
|
||||
}
|
||||
return nil
|
||||
|
@ -242,31 +269,6 @@ func (v *visitor) SetUser(u *user.User) {
|
|||
}
|
||||
}
|
||||
|
||||
func (v *visitor) resetLimiters() {
|
||||
log.Info("%s Resetting limiters for visitor", v.stringNoLock())
|
||||
var messagesLimiter, bandwidthLimiter util.Limiter
|
||||
var requestLimiter, emailsLimiter, accountLimiter *rate.Limiter
|
||||
if v.user != nil && v.user.Tier != nil {
|
||||
requestLimiter = rate.NewLimiter(dailyLimitToRate(v.user.Tier.MessagesLimit), v.config.VisitorRequestLimitBurst)
|
||||
messagesLimiter = util.NewFixedLimiter(v.user.Tier.MessagesLimit)
|
||||
emailsLimiter = rate.NewLimiter(dailyLimitToRate(v.user.Tier.EmailsLimit), v.config.VisitorEmailLimitBurst)
|
||||
bandwidthLimiter = util.NewBytesLimiter(int(v.user.Tier.AttachmentBandwidthLimit), 24*time.Hour)
|
||||
} else {
|
||||
requestLimiter = rate.NewLimiter(rate.Every(v.config.VisitorRequestLimitReplenish), v.config.VisitorRequestLimitBurst)
|
||||
messagesLimiter = nil // Message limit is governed by the requestLimiter
|
||||
emailsLimiter = rate.NewLimiter(rate.Every(v.config.VisitorEmailLimitReplenish), v.config.VisitorEmailLimitBurst)
|
||||
bandwidthLimiter = util.NewBytesLimiter(int(v.config.VisitorAttachmentDailyBandwidthLimit), 24*time.Hour)
|
||||
}
|
||||
if v.user == nil {
|
||||
accountLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAccountCreationLimitReplenish), v.config.VisitorAccountCreationLimitBurst)
|
||||
}
|
||||
v.requestLimiter = requestLimiter
|
||||
v.messagesLimiter = messagesLimiter
|
||||
v.emailsLimiter = emailsLimiter
|
||||
v.bandwidthLimiter = bandwidthLimiter
|
||||
v.accountLimiter = accountLimiter
|
||||
}
|
||||
|
||||
// MaybeUserID returns the user ID of the visitor (if any). If this is an anonymous visitor,
|
||||
// an empty string is returned.
|
||||
func (v *visitor) MaybeUserID() string {
|
||||
|
@ -278,22 +280,71 @@ func (v *visitor) MaybeUserID() string {
|
|||
return ""
|
||||
}
|
||||
|
||||
func (v *visitor) resetLimiters() {
|
||||
log.Debug("%s Resetting limiters for visitor", v.stringNoLock())
|
||||
limits := v.limitsNoLock()
|
||||
v.requestLimiter = rate.NewLimiter(limits.RequestLimitReplenish, limits.RequestLimitBurst)
|
||||
v.messagesLimiter = util.NewFixedLimiterWithValue(limits.MessageLimit, v.messages)
|
||||
v.emailsLimiter = rate.NewLimiter(limits.EmailLimitReplenish, limits.EmailLimitBurst)
|
||||
v.bandwidthLimiter = util.NewBytesLimiter(int(limits.AttachmentBandwidthLimit), oneDay)
|
||||
if v.user == nil {
|
||||
v.accountLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAccountCreationLimitReplenish), v.config.VisitorAccountCreationLimitBurst)
|
||||
} else {
|
||||
v.accountLimiter = nil // Users cannot create accounts when logged in
|
||||
}
|
||||
}
|
||||
|
||||
func (v *visitor) Limits() *visitorLimits {
|
||||
v.mu.Lock()
|
||||
defer v.mu.Unlock()
|
||||
limits := defaultVisitorLimits(v.config)
|
||||
return v.limitsNoLock()
|
||||
}
|
||||
|
||||
func (v *visitor) limitsNoLock() *visitorLimits {
|
||||
if v.user != nil && v.user.Tier != nil {
|
||||
limits.Basis = visitorLimitBasisTier
|
||||
limits.MessagesLimit = v.user.Tier.MessagesLimit
|
||||
limits.MessagesExpiryDuration = v.user.Tier.MessagesExpiryDuration
|
||||
limits.EmailsLimit = v.user.Tier.EmailsLimit
|
||||
limits.ReservationsLimit = v.user.Tier.ReservationsLimit
|
||||
limits.AttachmentTotalSizeLimit = v.user.Tier.AttachmentTotalSizeLimit
|
||||
limits.AttachmentFileSizeLimit = v.user.Tier.AttachmentFileSizeLimit
|
||||
limits.AttachmentExpiryDuration = v.user.Tier.AttachmentExpiryDuration
|
||||
limits.AttachmentBandwidthLimit = v.user.Tier.AttachmentBandwidthLimit
|
||||
return tierBasedVisitorLimits(v.config, v.user.Tier)
|
||||
}
|
||||
return configBasedVisitorLimits(v.config)
|
||||
}
|
||||
|
||||
func tierBasedVisitorLimits(conf *Config, tier *user.Tier) *visitorLimits {
|
||||
return &visitorLimits{
|
||||
Basis: visitorLimitBasisTier,
|
||||
RequestLimitBurst: util.MinMax(int(float64(tier.MessageLimit)*visitorMessageToRequestLimitBurstRate), conf.VisitorRequestLimitBurst, visitorMessageToRequestLimitBurstMax),
|
||||
RequestLimitReplenish: dailyLimitToRate(tier.MessageLimit * visitorMessageToRequestLimitReplenishFactor),
|
||||
MessageLimit: tier.MessageLimit,
|
||||
MessageExpiryDuration: tier.MessageExpiryDuration,
|
||||
EmailLimit: tier.EmailLimit,
|
||||
EmailLimitBurst: util.MinMax(int(float64(tier.EmailLimit)*visitorEmailLimitBurstRate), conf.VisitorEmailLimitBurst, visitorEmailLimitBurstMax),
|
||||
EmailLimitReplenish: dailyLimitToRate(tier.EmailLimit),
|
||||
ReservationsLimit: tier.ReservationLimit,
|
||||
AttachmentTotalSizeLimit: tier.AttachmentTotalSizeLimit,
|
||||
AttachmentFileSizeLimit: tier.AttachmentFileSizeLimit,
|
||||
AttachmentExpiryDuration: tier.AttachmentExpiryDuration,
|
||||
AttachmentBandwidthLimit: tier.AttachmentBandwidthLimit,
|
||||
}
|
||||
}
|
||||
|
||||
func configBasedVisitorLimits(conf *Config) *visitorLimits {
|
||||
messagesLimit := replenishDurationToDailyLimit(conf.VisitorRequestLimitReplenish) // Approximation!
|
||||
if conf.VisitorMessageDailyLimit > 0 {
|
||||
messagesLimit = int64(conf.VisitorMessageDailyLimit)
|
||||
}
|
||||
return &visitorLimits{
|
||||
Basis: visitorLimitBasisIP,
|
||||
RequestLimitBurst: conf.VisitorRequestLimitBurst,
|
||||
RequestLimitReplenish: rate.Every(conf.VisitorRequestLimitReplenish),
|
||||
MessageLimit: messagesLimit,
|
||||
MessageExpiryDuration: conf.CacheDuration,
|
||||
EmailLimit: replenishDurationToDailyLimit(conf.VisitorEmailLimitReplenish), // Approximation!
|
||||
EmailLimitBurst: conf.VisitorEmailLimitBurst,
|
||||
EmailLimitReplenish: rate.Every(conf.VisitorEmailLimitReplenish),
|
||||
ReservationsLimit: visitorDefaultReservationsLimit,
|
||||
AttachmentTotalSizeLimit: conf.VisitorAttachmentTotalSizeLimit,
|
||||
AttachmentFileSizeLimit: conf.AttachmentFileSizeLimit,
|
||||
AttachmentExpiryDuration: conf.AttachmentExpiryDuration,
|
||||
AttachmentBandwidthLimit: conf.VisitorAttachmentDailyBandwidthLimit,
|
||||
}
|
||||
return limits
|
||||
}
|
||||
|
||||
func (v *visitor) Info() (*visitorInfo, error) {
|
||||
|
@ -321,9 +372,9 @@ func (v *visitor) Info() (*visitorInfo, error) {
|
|||
limits := v.Limits()
|
||||
stats := &visitorStats{
|
||||
Messages: messages,
|
||||
MessagesRemaining: zeroIfNegative(limits.MessagesLimit - messages),
|
||||
MessagesRemaining: zeroIfNegative(limits.MessageLimit - messages),
|
||||
Emails: emails,
|
||||
EmailsRemaining: zeroIfNegative(limits.EmailsLimit - emails),
|
||||
EmailsRemaining: zeroIfNegative(limits.EmailLimit - emails),
|
||||
Reservations: reservations,
|
||||
ReservationsRemaining: zeroIfNegative(limits.ReservationsLimit - reservations),
|
||||
AttachmentTotalSize: attachmentsBytesUsed,
|
||||
|
@ -343,23 +394,16 @@ func zeroIfNegative(value int64) int64 {
|
|||
}
|
||||
|
||||
func replenishDurationToDailyLimit(duration time.Duration) int64 {
|
||||
return int64(24 * time.Hour / duration)
|
||||
return int64(oneDay / duration)
|
||||
}
|
||||
|
||||
func dailyLimitToRate(limit int64) rate.Limit {
|
||||
return rate.Limit(limit) * rate.Every(24*time.Hour)
|
||||
return rate.Limit(limit) * rate.Every(oneDay)
|
||||
}
|
||||
|
||||
func defaultVisitorLimits(conf *Config) *visitorLimits {
|
||||
return &visitorLimits{
|
||||
Basis: visitorLimitBasisIP,
|
||||
MessagesLimit: replenishDurationToDailyLimit(conf.VisitorRequestLimitReplenish),
|
||||
MessagesExpiryDuration: conf.CacheDuration,
|
||||
EmailsLimit: replenishDurationToDailyLimit(conf.VisitorEmailLimitReplenish),
|
||||
ReservationsLimit: visitorDefaultReservationsLimit,
|
||||
AttachmentTotalSizeLimit: conf.VisitorAttachmentTotalSizeLimit,
|
||||
AttachmentFileSizeLimit: conf.AttachmentFileSizeLimit,
|
||||
AttachmentExpiryDuration: conf.AttachmentExpiryDuration,
|
||||
AttachmentBandwidthLimit: conf.VisitorAttachmentDailyBandwidthLimit,
|
||||
func visitorID(ip netip.Addr, u *user.User) string {
|
||||
if u != nil && u.Tier != nil {
|
||||
return fmt.Sprintf("user:%s", u.ID)
|
||||
}
|
||||
return fmt.Sprintf("ip:%s", ip.String())
|
||||
}
|
||||
|
|
|
@ -709,10 +709,10 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
|
|||
ID: tierID.String,
|
||||
Code: tierCode.String,
|
||||
Name: tierName.String,
|
||||
MessagesLimit: messagesLimit.Int64,
|
||||
MessagesExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
|
||||
EmailsLimit: emailsLimit.Int64,
|
||||
ReservationsLimit: reservationsLimit.Int64,
|
||||
MessageLimit: messagesLimit.Int64,
|
||||
MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
|
||||
EmailLimit: emailsLimit.Int64,
|
||||
ReservationLimit: reservationsLimit.Int64,
|
||||
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
|
||||
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
|
||||
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
|
||||
|
@ -845,7 +845,7 @@ func (a *Manager) ChangeTier(username, tier string) error {
|
|||
t, err := a.Tier(tier)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if err := a.checkReservationsLimit(username, t.ReservationsLimit); err != nil {
|
||||
} else if err := a.checkReservationsLimit(username, t.ReservationLimit); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := a.db.Exec(updateUserTierQuery, tier, username); err != nil {
|
||||
|
@ -870,7 +870,7 @@ func (a *Manager) checkReservationsLimit(username string, reservationsLimit int6
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if u.Tier != nil && reservationsLimit < u.Tier.ReservationsLimit {
|
||||
if u.Tier != nil && reservationsLimit < u.Tier.ReservationLimit {
|
||||
reservations, err := a.Reservations(username)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -999,7 +999,7 @@ func (a *Manager) CreateTier(tier *Tier) error {
|
|||
if tier.ID == "" {
|
||||
tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength)
|
||||
}
|
||||
if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, tier.StripePriceID); err != nil {
|
||||
if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, tier.StripePriceID); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
@ -1070,10 +1070,10 @@ func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
|
|||
ID: id,
|
||||
Code: code,
|
||||
Name: name,
|
||||
MessagesLimit: messagesLimit.Int64,
|
||||
MessagesExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
|
||||
EmailsLimit: emailsLimit.Int64,
|
||||
ReservationsLimit: reservationsLimit.Int64,
|
||||
MessageLimit: messagesLimit.Int64,
|
||||
MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
|
||||
EmailLimit: emailsLimit.Int64,
|
||||
ReservationLimit: reservationsLimit.Int64,
|
||||
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
|
||||
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
|
||||
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
|
||||
|
|
|
@ -335,10 +335,10 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
|
|||
Code: "pro",
|
||||
Name: "ntfy Pro",
|
||||
StripePriceID: "price123",
|
||||
MessagesLimit: 5_000,
|
||||
MessagesExpiryDuration: 3 * 24 * time.Hour,
|
||||
EmailsLimit: 50,
|
||||
ReservationsLimit: 5,
|
||||
MessageLimit: 5_000,
|
||||
MessageExpiryDuration: 3 * 24 * time.Hour,
|
||||
EmailLimit: 50,
|
||||
ReservationLimit: 5,
|
||||
AttachmentFileSizeLimit: 52428800,
|
||||
AttachmentTotalSizeLimit: 524288000,
|
||||
AttachmentExpiryDuration: 24 * time.Hour,
|
||||
|
@ -351,10 +351,10 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
|
|||
require.Nil(t, err)
|
||||
require.Equal(t, RoleUser, ben.Role)
|
||||
require.Equal(t, "pro", ben.Tier.Code)
|
||||
require.Equal(t, int64(5000), ben.Tier.MessagesLimit)
|
||||
require.Equal(t, 3*24*time.Hour, ben.Tier.MessagesExpiryDuration)
|
||||
require.Equal(t, int64(50), ben.Tier.EmailsLimit)
|
||||
require.Equal(t, int64(5), ben.Tier.ReservationsLimit)
|
||||
require.Equal(t, int64(5000), ben.Tier.MessageLimit)
|
||||
require.Equal(t, 3*24*time.Hour, ben.Tier.MessageExpiryDuration)
|
||||
require.Equal(t, int64(50), ben.Tier.EmailLimit)
|
||||
require.Equal(t, int64(5), ben.Tier.ReservationLimit)
|
||||
require.Equal(t, int64(52428800), ben.Tier.AttachmentFileSizeLimit)
|
||||
require.Equal(t, int64(524288000), ben.Tier.AttachmentTotalSizeLimit)
|
||||
require.Equal(t, 24*time.Hour, ben.Tier.AttachmentExpiryDuration)
|
||||
|
|
|
@ -62,10 +62,10 @@ type Tier struct {
|
|||
ID string // Tier identifier (ti_...)
|
||||
Code string // Code of the tier
|
||||
Name string // Name of the tier
|
||||
MessagesLimit int64 // Daily message limit
|
||||
MessagesExpiryDuration time.Duration // Cache duration for messages
|
||||
EmailsLimit int64 // Daily email limit
|
||||
ReservationsLimit int64 // Number of topic reservations allowed by user
|
||||
MessageLimit int64 // Daily message limit
|
||||
MessageExpiryDuration time.Duration // Cache duration for messages
|
||||
EmailLimit int64 // Daily email limit
|
||||
ReservationLimit int64 // Number of topic reservations allowed by user
|
||||
AttachmentFileSizeLimit int64 // Max file size per file (bytes)
|
||||
AttachmentTotalSizeLimit int64 // Total file size for all files of this user (bytes)
|
||||
AttachmentExpiryDuration time.Duration // Duration after which attachments will be deleted
|
||||
|
|
|
@ -27,8 +27,14 @@ type FixedLimiter struct {
|
|||
|
||||
// NewFixedLimiter creates a new Limiter
|
||||
func NewFixedLimiter(limit int64) *FixedLimiter {
|
||||
return NewFixedLimiterWithValue(limit, 0)
|
||||
}
|
||||
|
||||
// NewFixedLimiterWithValue creates a new Limiter and sets the initial value
|
||||
func NewFixedLimiterWithValue(limit, value int64) *FixedLimiter {
|
||||
return &FixedLimiter{
|
||||
limit: limit,
|
||||
value: value,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@ var (
|
|||
// NextOccurrenceUTC takes a time of day (e.g. 9:00am), and returns the next occurrence
|
||||
// of that time from the current time (in UTC).
|
||||
func NextOccurrenceUTC(timeOfDay, base time.Time) time.Time {
|
||||
hour, minute, seconds := timeOfDay.Clock()
|
||||
hour, minute, seconds := timeOfDay.UTC().Clock()
|
||||
now := base.UTC()
|
||||
next := time.Date(now.Year(), now.Month(), now.Day(), hour, minute, seconds, 0, time.UTC)
|
||||
if next.Before(now) {
|
||||
|
|
11
util/util.go
11
util/util.go
|
@ -337,6 +337,17 @@ func Retry[T any](f func() (*T, error), after ...time.Duration) (t *T, err error
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// MinMax returns value if it is between min and max, or either
|
||||
// min or max if it is out of range
|
||||
func MinMax[T int | int64](value, min, max T) T {
|
||||
if value < min {
|
||||
return min
|
||||
} else if value > max {
|
||||
return max
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// String turns a string into a pointer of a string
|
||||
func String(v string) *string {
|
||||
return &v
|
||||
|
|
Loading…
Reference in a new issue