1
0
Fork 0
mirror of https://github.com/binwiederhier/ntfy.git synced 2025-06-24 13:28:19 +02:00

Rate limits make sense now!

This commit is contained in:
binwiederhier 2023-01-26 22:57:18 -05:00
parent a036814d98
commit c874a641df
17 changed files with 365 additions and 205 deletions
server

View file

@ -38,10 +38,9 @@ import (
TODO
--
- HIGH Rate limiting: dailyLimitToRate is wrong? + TESTS
- HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...)
- HIGH Rate limiting: Delete visitor when tier is changed to refresh rate limiters
- HIGH Rate limiting: When ResetStats() is run, reset messagesLimiter (and others)?
- MEDIUM Rate limiting: Test daily message quota read from database initially
- MEDIUM: Races with v.user (see publishSyncEventAsync test)
- MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
- MEDIUM: Reservation (UI): Ask for confirmation when removing reservation (deadcade)
@ -57,7 +56,6 @@ Make sure account endpoints make sense for admins
Tests:
- Payment endpoints (make mocks)
- test that the visitor is based on the IP address when a user has no tier
*/
// Server is the main server, providing the UI and API for ntfy
@ -308,7 +306,7 @@ func (s *Server) Stop() {
}
func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
v, err := s.visitor(r) // Note: Always returns v, even when error is returned
v, err := s.maybeAuthenticate(r) // Note: Always returns v, even when error is returned
if err == nil {
log.Debug("%s Dispatching request", logHTTPPrefix(v, r))
if log.IsTrace() {
@ -563,7 +561,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
if v.user != nil {
m.User = v.user.ID
}
m.Expires = time.Now().Add(v.Limits().MessagesExpiryDuration).Unix()
m.Expires = time.Now().Add(v.Limits().MessageExpiryDuration).Unix()
if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil {
return nil, err
}
@ -601,7 +599,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
}
v.IncrementMessages()
if s.userManager != nil && v.user != nil {
s.userManager.EnqueueStats(v.user)
s.userManager.EnqueueStats(v.user) // FIXME this makes no sense for tier-less users
}
s.mu.Lock()
s.messages++
@ -1382,8 +1380,10 @@ func (s *Server) runStatsResetter() {
log.Debug("Stats resetter: Waiting until %v to reset visitor stats", runAt)
select {
case <-timer.C:
log.Debug("Stats resetter: Running")
s.resetStats()
case <-s.closeChan:
log.Debug("Stats resetter: Stopping timer")
timer.Stop()
return
}
@ -1440,17 +1440,15 @@ func (s *Server) sendDelayedMessages() error {
return err
}
for _, m := range messages {
var v *visitor
var u *user.User
if s.userManager != nil && m.User != "" {
u, err := s.userManager.User(m.User)
u, err = s.userManager.User(m.User)
if err != nil {
log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error())
log.Warn("Error sending delayed message %s: %s", m.ID, err.Error())
continue
}
v = s.visitorFromUser(u, m.Sender)
} else {
v = s.visitorFromIP(m.Sender)
}
v := s.visitor(m.Sender, u)
if err := s.sendDelayedMessage(v, m); err != nil {
log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error())
}
@ -1588,20 +1586,16 @@ func (s *Server) autorizeTopic(next handleFunc, perm user.Permission) handleFunc
}
}
// visitor creates or retrieves a rate.Limiter for the given visitor.
// maybeAuthenticate creates or retrieves a rate.Limiter for the given visitor.
// Note that this function will always return a visitor, even if an error occurs.
func (s *Server) visitor(r *http.Request) (v *visitor, err error) {
func (s *Server) maybeAuthenticate(r *http.Request) (v *visitor, err error) {
ip := extractIPAddress(r, s.config.BehindProxy)
var u *user.User // may stay nil if no auth header!
if u, err = s.authenticate(r); err != nil {
log.Debug("authentication failed: %s", err.Error())
err = errHTTPUnauthorized // Always return visitor, even when error occurs!
}
if u != nil {
v = s.visitorFromUser(u, ip)
} else {
v = s.visitorFromIP(ip)
}
v = s.visitor(ip, u)
v.SetUser(u) // Update visitor user with latest from database!
return v, err // Always return visitor, even when error occurs!
}
@ -1645,26 +1639,19 @@ func (s *Server) authenticateBearerAuth(value string) (user *user.User, err erro
return s.userManager.AuthenticateToken(token)
}
func (s *Server) visitorFromID(visitorID string, ip netip.Addr, user *user.User) *visitor {
func (s *Server) visitor(ip netip.Addr, user *user.User) *visitor {
s.mu.Lock()
defer s.mu.Unlock()
v, exists := s.visitors[visitorID]
id := visitorID(ip, user)
v, exists := s.visitors[id]
if !exists {
s.visitors[visitorID] = newVisitor(s.config, s.messageCache, s.userManager, ip, user)
return s.visitors[visitorID]
s.visitors[id] = newVisitor(s.config, s.messageCache, s.userManager, ip, user)
return s.visitors[id]
}
v.Keepalive()
return v
}
func (s *Server) visitorFromIP(ip netip.Addr) *visitor {
return s.visitorFromID(fmt.Sprintf("ip:%s", ip.String()), ip, nil)
}
func (s *Server) visitorFromUser(user *user.User, ip netip.Addr) *visitor {
return s.visitorFromID(fmt.Sprintf("user:%s", user.ID), ip, user)
}
func (s *Server) writeJSON(w http.ResponseWriter, v any) error {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests