mirror of
https://github.com/binwiederhier/ntfy.git
synced 2025-06-24 13:28:19 +02:00
Rate limits make sense now!
This commit is contained in:
parent
a036814d98
commit
c874a641df
17 changed files with 365 additions and 205 deletions
server
|
@ -38,10 +38,9 @@ import (
|
|||
TODO
|
||||
--
|
||||
|
||||
- HIGH Rate limiting: dailyLimitToRate is wrong? + TESTS
|
||||
- HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...)
|
||||
- HIGH Rate limiting: Delete visitor when tier is changed to refresh rate limiters
|
||||
- HIGH Rate limiting: When ResetStats() is run, reset messagesLimiter (and others)?
|
||||
- MEDIUM Rate limiting: Test daily message quota read from database initially
|
||||
- MEDIUM: Races with v.user (see publishSyncEventAsync test)
|
||||
- MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
|
||||
- MEDIUM: Reservation (UI): Ask for confirmation when removing reservation (deadcade)
|
||||
|
@ -57,7 +56,6 @@ Make sure account endpoints make sense for admins
|
|||
|
||||
Tests:
|
||||
- Payment endpoints (make mocks)
|
||||
- test that the visitor is based on the IP address when a user has no tier
|
||||
*/
|
||||
|
||||
// Server is the main server, providing the UI and API for ntfy
|
||||
|
@ -308,7 +306,7 @@ func (s *Server) Stop() {
|
|||
}
|
||||
|
||||
func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
|
||||
v, err := s.visitor(r) // Note: Always returns v, even when error is returned
|
||||
v, err := s.maybeAuthenticate(r) // Note: Always returns v, even when error is returned
|
||||
if err == nil {
|
||||
log.Debug("%s Dispatching request", logHTTPPrefix(v, r))
|
||||
if log.IsTrace() {
|
||||
|
@ -563,7 +561,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
|
|||
if v.user != nil {
|
||||
m.User = v.user.ID
|
||||
}
|
||||
m.Expires = time.Now().Add(v.Limits().MessagesExpiryDuration).Unix()
|
||||
m.Expires = time.Now().Add(v.Limits().MessageExpiryDuration).Unix()
|
||||
if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -601,7 +599,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
|
|||
}
|
||||
v.IncrementMessages()
|
||||
if s.userManager != nil && v.user != nil {
|
||||
s.userManager.EnqueueStats(v.user)
|
||||
s.userManager.EnqueueStats(v.user) // FIXME this makes no sense for tier-less users
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.messages++
|
||||
|
@ -1382,8 +1380,10 @@ func (s *Server) runStatsResetter() {
|
|||
log.Debug("Stats resetter: Waiting until %v to reset visitor stats", runAt)
|
||||
select {
|
||||
case <-timer.C:
|
||||
log.Debug("Stats resetter: Running")
|
||||
s.resetStats()
|
||||
case <-s.closeChan:
|
||||
log.Debug("Stats resetter: Stopping timer")
|
||||
timer.Stop()
|
||||
return
|
||||
}
|
||||
|
@ -1440,17 +1440,15 @@ func (s *Server) sendDelayedMessages() error {
|
|||
return err
|
||||
}
|
||||
for _, m := range messages {
|
||||
var v *visitor
|
||||
var u *user.User
|
||||
if s.userManager != nil && m.User != "" {
|
||||
u, err := s.userManager.User(m.User)
|
||||
u, err = s.userManager.User(m.User)
|
||||
if err != nil {
|
||||
log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error())
|
||||
log.Warn("Error sending delayed message %s: %s", m.ID, err.Error())
|
||||
continue
|
||||
}
|
||||
v = s.visitorFromUser(u, m.Sender)
|
||||
} else {
|
||||
v = s.visitorFromIP(m.Sender)
|
||||
}
|
||||
v := s.visitor(m.Sender, u)
|
||||
if err := s.sendDelayedMessage(v, m); err != nil {
|
||||
log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error())
|
||||
}
|
||||
|
@ -1588,20 +1586,16 @@ func (s *Server) autorizeTopic(next handleFunc, perm user.Permission) handleFunc
|
|||
}
|
||||
}
|
||||
|
||||
// visitor creates or retrieves a rate.Limiter for the given visitor.
|
||||
// maybeAuthenticate creates or retrieves a rate.Limiter for the given visitor.
|
||||
// Note that this function will always return a visitor, even if an error occurs.
|
||||
func (s *Server) visitor(r *http.Request) (v *visitor, err error) {
|
||||
func (s *Server) maybeAuthenticate(r *http.Request) (v *visitor, err error) {
|
||||
ip := extractIPAddress(r, s.config.BehindProxy)
|
||||
var u *user.User // may stay nil if no auth header!
|
||||
if u, err = s.authenticate(r); err != nil {
|
||||
log.Debug("authentication failed: %s", err.Error())
|
||||
err = errHTTPUnauthorized // Always return visitor, even when error occurs!
|
||||
}
|
||||
if u != nil {
|
||||
v = s.visitorFromUser(u, ip)
|
||||
} else {
|
||||
v = s.visitorFromIP(ip)
|
||||
}
|
||||
v = s.visitor(ip, u)
|
||||
v.SetUser(u) // Update visitor user with latest from database!
|
||||
return v, err // Always return visitor, even when error occurs!
|
||||
}
|
||||
|
@ -1645,26 +1639,19 @@ func (s *Server) authenticateBearerAuth(value string) (user *user.User, err erro
|
|||
return s.userManager.AuthenticateToken(token)
|
||||
}
|
||||
|
||||
func (s *Server) visitorFromID(visitorID string, ip netip.Addr, user *user.User) *visitor {
|
||||
func (s *Server) visitor(ip netip.Addr, user *user.User) *visitor {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
v, exists := s.visitors[visitorID]
|
||||
id := visitorID(ip, user)
|
||||
v, exists := s.visitors[id]
|
||||
if !exists {
|
||||
s.visitors[visitorID] = newVisitor(s.config, s.messageCache, s.userManager, ip, user)
|
||||
return s.visitors[visitorID]
|
||||
s.visitors[id] = newVisitor(s.config, s.messageCache, s.userManager, ip, user)
|
||||
return s.visitors[id]
|
||||
}
|
||||
v.Keepalive()
|
||||
return v
|
||||
}
|
||||
|
||||
func (s *Server) visitorFromIP(ip netip.Addr) *visitor {
|
||||
return s.visitorFromID(fmt.Sprintf("ip:%s", ip.String()), ip, nil)
|
||||
}
|
||||
|
||||
func (s *Server) visitorFromUser(user *user.User, ip netip.Addr) *visitor {
|
||||
return s.visitorFromID(fmt.Sprintf("user:%s", user.ID), ip, user)
|
||||
}
|
||||
|
||||
func (s *Server) writeJSON(w http.ResponseWriter, v any) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue