mirror of
https://github.com/binwiederhier/ntfy.git
synced 2024-11-04 19:04:15 +01:00
rate limiting impl 2.0?
This commit is contained in:
parent
36685e9df9
commit
1655f584f9
4 changed files with 97 additions and 61 deletions
|
@ -3,8 +3,9 @@ package server
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"heckel.io/ntfy/log"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"heckel.io/ntfy/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
// errHTTP is a generic HTTP error for any non-200 HTTP error
|
// errHTTP is a generic HTTP error for any non-200 HTTP error
|
||||||
|
@ -92,5 +93,4 @@ var (
|
||||||
errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""}
|
errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""}
|
||||||
errHTTPInternalErrorInvalidPath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid path", ""}
|
errHTTPInternalErrorInvalidPath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid path", ""}
|
||||||
errHTTPInternalErrorMissingBaseURL = &errHTTP{50003, http.StatusInternalServerError, "internal server error: base-url must be be configured for this feature", "https://ntfy.sh/docs/config/"}
|
errHTTPInternalErrorMissingBaseURL = &errHTTP{50003, http.StatusInternalServerError, "internal server error: base-url must be be configured for this feature", "https://ntfy.sh/docs/config/"}
|
||||||
errHTTPWontStoreMessage = &errHTTP{50701, http.StatusInsufficientStorage, "topic is inactive; no device available to recieve message", ""}
|
|
||||||
)
|
)
|
||||||
|
|
|
@ -9,12 +9,6 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/emersion/go-smtp"
|
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
"golang.org/x/sync/errgroup"
|
|
||||||
"heckel.io/ntfy/log"
|
|
||||||
"heckel.io/ntfy/user"
|
|
||||||
"heckel.io/ntfy/util"
|
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -30,6 +24,13 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
|
||||||
|
"github.com/emersion/go-smtp"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
"heckel.io/ntfy/log"
|
||||||
|
"heckel.io/ntfy/user"
|
||||||
|
"heckel.io/ntfy/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
@ -605,23 +606,23 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
v_old := v
|
vRate := v
|
||||||
if strings.HasPrefix(t.ID, subscriberBilledTopicPrefix) {
|
if topicCountsAgainst := t.Billee(); topicCountsAgainst != nil {
|
||||||
v = t.getBillee()
|
vRate = topicCountsAgainst
|
||||||
if v == nil {
|
|
||||||
return nil, errHTTPWontStoreMessage
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !vRate.MessageAllowed() {
|
||||||
|
vRate = v
|
||||||
if !v.MessageAllowed() {
|
if !v.MessageAllowed() {
|
||||||
return nil, errHTTPTooManyRequestsLimitMessages
|
return nil, errHTTPTooManyRequestsLimitMessages
|
||||||
}
|
}
|
||||||
|
}
|
||||||
body, err := util.Peek(r.Body, s.config.MessageLimit)
|
body, err := util.Peek(r.Body, s.config.MessageLimit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
m := newDefaultMessage(t.ID, "")
|
m := newDefaultMessage(t.ID, "")
|
||||||
cache, firebase, email, unifiedpush, err := s.parsePublishParams(r, v, m)
|
cache, firebase, email, unifiedpush, err := s.parsePublishParams(r, vRate, m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -630,7 +631,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
|
||||||
}
|
}
|
||||||
m.Sender = v.IP()
|
m.Sender = v.IP()
|
||||||
m.User = v.MaybeUserID()
|
m.User = v.MaybeUserID()
|
||||||
m.Expires = time.Unix(m.Time, 0).Add(v.Limits().MessageExpiryDuration).Unix()
|
m.Expires = time.Unix(m.Time, 0).Add(vRate.Limits().MessageExpiryDuration).Unix()
|
||||||
if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil {
|
if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -638,18 +639,18 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
|
||||||
m.Message = emptyMessageBody
|
m.Message = emptyMessageBody
|
||||||
}
|
}
|
||||||
delayed := m.Time > time.Now().Unix()
|
delayed := m.Time > time.Now().Unix()
|
||||||
logvrm(v, r, m).
|
logvrm(vRate, r, m).
|
||||||
Tag(tagPublish).
|
Tag(tagPublish).
|
||||||
Fields(log.Context{
|
Fields(log.Context{
|
||||||
"message_delayed": delayed,
|
"message_delayed": delayed,
|
||||||
"message_firebase": firebase,
|
"message_firebase": firebase,
|
||||||
"message_unifiedpush": unifiedpush,
|
"message_unifiedpush": unifiedpush,
|
||||||
"message_email": email,
|
"message_email": email,
|
||||||
|
"message_subscriber_rate_limited": vRate != v,
|
||||||
}).
|
}).
|
||||||
Debug("Received message")
|
Debug("Received message")
|
||||||
//Where should I log the original visitor vs the billing visitor
|
|
||||||
if log.IsTrace() {
|
if log.IsTrace() {
|
||||||
logvrm(v_old, r, m).
|
logvrm(vRate, r, m).
|
||||||
Tag(tagPublish).
|
Tag(tagPublish).
|
||||||
Field("message_body", util.MaybeMarshalJSON(m)).
|
Field("message_body", util.MaybeMarshalJSON(m)).
|
||||||
Trace("Message body")
|
Trace("Message body")
|
||||||
|
@ -659,7 +660,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if s.firebaseClient != nil && firebase {
|
if s.firebaseClient != nil && firebase {
|
||||||
go s.sendToFirebase(v, m)
|
go s.sendToFirebase(vRate, m)
|
||||||
}
|
}
|
||||||
if s.smtpSender != nil && email != "" {
|
if s.smtpSender != nil && email != "" {
|
||||||
go s.sendEmail(v, m, email)
|
go s.sendEmail(v, m, email)
|
||||||
|
@ -745,7 +746,7 @@ func (s *Server) forwardPollRequest(v *visitor, m *message) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) {
|
func (s *Server) parsePublishParams(r *http.Request, vRate *visitor, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) {
|
||||||
cache = readBoolParam(r, true, "x-cache", "cache")
|
cache = readBoolParam(r, true, "x-cache", "cache")
|
||||||
firebase = readBoolParam(r, true, "x-firebase", "firebase")
|
firebase = readBoolParam(r, true, "x-firebase", "firebase")
|
||||||
m.Title = readParam(r, "x-title", "title", "t")
|
m.Title = readParam(r, "x-title", "title", "t")
|
||||||
|
@ -785,7 +786,7 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca
|
||||||
}
|
}
|
||||||
email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e")
|
email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e")
|
||||||
if email != "" {
|
if email != "" {
|
||||||
if !v.EmailAllowed() {
|
if !vRate.EmailAllowed() {
|
||||||
return false, false, "", false, errHTTPTooManyRequestsLimitEmails
|
return false, false, "", false, errHTTPTooManyRequestsLimitEmails
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -800,13 +801,7 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, false, "", false, errHTTPBadRequestPriorityInvalid
|
return false, false, "", false, errHTTPBadRequestPriorityInvalid
|
||||||
}
|
}
|
||||||
tagsStr := readParam(r, "x-tags", "tags", "tag", "ta")
|
m.Tags = readCommaSeperatedParam(r, "x-tags", "tags", "tag", "ta")
|
||||||
if tagsStr != "" {
|
|
||||||
m.Tags = make([]string, 0)
|
|
||||||
for _, s := range util.SplitNoEmpty(tagsStr, ",") {
|
|
||||||
m.Tags = append(m.Tags, strings.TrimSpace(s))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
delayStr := readParam(r, "x-delay", "delay", "x-at", "at", "x-in", "in")
|
delayStr := readParam(r, "x-delay", "delay", "x-at", "at", "x-in", "in")
|
||||||
if delayStr != "" {
|
if delayStr != "" {
|
||||||
if !cache {
|
if !cache {
|
||||||
|
@ -996,7 +991,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
poll, since, scheduled, filters, err := parseSubscribeParams(r)
|
poll, since, scheduled, filters, subscriberRateTopics, err := parseSubscribeParams(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -1035,7 +1030,8 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
||||||
defer cancel()
|
defer cancel()
|
||||||
subscriberIDs := make([]int, 0)
|
subscriberIDs := make([]int, 0)
|
||||||
for _, t := range topics {
|
for _, t := range topics {
|
||||||
subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel))
|
subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, subscriberBilledTopicPrefix) // temporarily do prefix as well
|
||||||
|
subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel, subscriberRateLimited))
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
for i, subscriberID := range subscriberIDs {
|
for i, subscriberID := range subscriberIDs {
|
||||||
|
@ -1078,7 +1074,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
poll, since, scheduled, filters, err := parseSubscribeParams(r)
|
poll, since, scheduled, filters, subscriberRateTopics, err := parseSubscribeParams(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -1167,7 +1163,8 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
||||||
}
|
}
|
||||||
subscriberIDs := make([]int, 0)
|
subscriberIDs := make([]int, 0)
|
||||||
for _, t := range topics {
|
for _, t := range topics {
|
||||||
subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel))
|
subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, subscriberBilledTopicPrefix) // temporarily do prefix as well
|
||||||
|
subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel, subscriberRateLimited))
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
for i, subscriberID := range subscriberIDs {
|
for i, subscriberID := range subscriberIDs {
|
||||||
|
@ -1188,7 +1185,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, scheduled bool, filters *queryFilter, err error) {
|
func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, scheduled bool, filters *queryFilter, subscriberTopics []string, err error) {
|
||||||
poll = readBoolParam(r, false, "x-poll", "poll", "po")
|
poll = readBoolParam(r, false, "x-poll", "poll", "po")
|
||||||
scheduled = readBoolParam(r, false, "x-scheduled", "scheduled", "sched")
|
scheduled = readBoolParam(r, false, "x-scheduled", "scheduled", "sched")
|
||||||
since, err = parseSince(r, poll)
|
since, err = parseSince(r, poll)
|
||||||
|
@ -1199,6 +1196,8 @@ func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, schedu
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
subscriberTopics = readCommaSeperatedParam(r, "subscriber-rate-limit-topics", "x-subscriber-rate-limit-topics", "srlt")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@ type topicSubscriber struct {
|
||||||
subscriber subscriber
|
subscriber subscriber
|
||||||
visitor *visitor // User ID associated with this subscription, may be empty
|
visitor *visitor // User ID associated with this subscription, may be empty
|
||||||
cancel func()
|
cancel func()
|
||||||
|
subscriberRateLimit bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// subscriber is a function that is called for every new message on a topic
|
// subscriber is a function that is called for every new message on a topic
|
||||||
|
@ -36,7 +37,7 @@ func newTopic(id string) *topic {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Subscribe subscribes to this topic
|
// Subscribe subscribes to this topic
|
||||||
func (t *topic) Subscribe(s subscriber, visitor *visitor, cancel func()) int {
|
func (t *topic) Subscribe(s subscriber, visitor *visitor, cancel func(), subscriberRateLimit bool) int {
|
||||||
t.mu.Lock()
|
t.mu.Lock()
|
||||||
defer t.mu.Unlock()
|
defer t.mu.Unlock()
|
||||||
subscriberID := rand.Int()
|
subscriberID := rand.Int()
|
||||||
|
@ -44,23 +45,28 @@ func (t *topic) Subscribe(s subscriber, visitor *visitor, cancel func()) int {
|
||||||
visitor: visitor, // May be empty
|
visitor: visitor, // May be empty
|
||||||
subscriber: s,
|
subscriber: s,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
|
subscriberRateLimit: subscriberRateLimit,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if no subscriber is already handling the rate limit
|
||||||
|
if t.lastVisitor == nil && subscriberRateLimit {
|
||||||
|
t.lastVisitor = visitor
|
||||||
|
t.lastVisitorExpires = time.Time{}
|
||||||
|
}
|
||||||
|
|
||||||
return subscriberID
|
return subscriberID
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *topic) Stale() bool {
|
func (t *topic) Stale() bool {
|
||||||
return t.getBillee() == nil
|
// if Time is initialized (not the zero value) and the expiry time has passed
|
||||||
}
|
if !t.lastVisitorExpires.IsZero() && t.lastVisitorExpires.Before(time.Now()) {
|
||||||
|
|
||||||
func (t *topic) getBillee() *visitor {
|
|
||||||
for _, this_subscriber := range t.subscribers {
|
|
||||||
return this_subscriber.visitor
|
|
||||||
}
|
|
||||||
if t.lastVisitor != nil && t.lastVisitorExpires.After(time.Now()) {
|
|
||||||
t.lastVisitor = nil
|
t.lastVisitor = nil
|
||||||
}
|
}
|
||||||
return t.lastVisitor
|
return len(t.subscribers) == 0 && t.lastVisitor == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *topic) Billee() *visitor {
|
||||||
|
return t.lastVisitor
|
||||||
}
|
}
|
||||||
|
|
||||||
// Unsubscribe removes the subscription from the list of subscribers
|
// Unsubscribe removes the subscription from the list of subscribers
|
||||||
|
@ -68,11 +74,23 @@ func (t *topic) Unsubscribe(id int) {
|
||||||
t.mu.Lock()
|
t.mu.Lock()
|
||||||
defer t.mu.Unlock()
|
defer t.mu.Unlock()
|
||||||
|
|
||||||
if len(t.subscribers) == 1 {
|
deletingSub := t.subscribers[id]
|
||||||
t.lastVisitor = t.subscribers[id].visitor
|
delete(t.subscribers, id)
|
||||||
|
|
||||||
|
// look for an active subscriber (in random order) that wants to handle the rate limit
|
||||||
|
for _, v := range t.subscribers {
|
||||||
|
if v.subscriberRateLimit {
|
||||||
|
t.lastVisitor = v.visitor
|
||||||
|
t.lastVisitorExpires = time.Time{}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// if no active subscriber is found, count it towards the leaving subscriber
|
||||||
|
if deletingSub.subscriberRateLimit {
|
||||||
|
t.lastVisitor = deletingSub.visitor
|
||||||
t.lastVisitorExpires = time.Now().Add(subscriberBilledValidity)
|
t.lastVisitorExpires = time.Now().Add(subscriberBilledValidity)
|
||||||
}
|
}
|
||||||
delete(t.subscribers, id)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Publish asynchronously publishes to all subscribers
|
// Publish asynchronously publishes to all subscribers
|
||||||
|
|
|
@ -1,12 +1,13 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"heckel.io/ntfy/log"
|
|
||||||
"heckel.io/ntfy/util"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"heckel.io/ntfy/log"
|
||||||
|
"heckel.io/ntfy/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool {
|
func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool {
|
||||||
|
@ -17,6 +18,17 @@ func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool {
|
||||||
return value == "1" || value == "yes" || value == "true"
|
return value == "1" || value == "yes" || value == "true"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func readCommaSeperatedParam(r *http.Request, names ...string) (params []string) {
|
||||||
|
paramStr := readParam(r, names...)
|
||||||
|
if paramStr != "" {
|
||||||
|
params = make([]string, 0)
|
||||||
|
for _, s := range util.SplitNoEmpty(paramStr, ",") {
|
||||||
|
params = append(params, strings.TrimSpace(s))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return params
|
||||||
|
}
|
||||||
|
|
||||||
func readParam(r *http.Request, names ...string) string {
|
func readParam(r *http.Request, names ...string) string {
|
||||||
value := readHeaderParam(r, names...)
|
value := readHeaderParam(r, names...)
|
||||||
if value != "" {
|
if value != "" {
|
||||||
|
@ -35,6 +47,13 @@ func readHeaderParam(r *http.Request, names ...string) string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func readHeaderParamValues(r *http.Request, names ...string) (values []string) {
|
||||||
|
for _, name := range names {
|
||||||
|
values = append(values, r.Header.Values(name)...)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func readQueryParam(r *http.Request, names ...string) string {
|
func readQueryParam(r *http.Request, names ...string) string {
|
||||||
for _, name := range names {
|
for _, name := range names {
|
||||||
value := r.URL.Query().Get(strings.ToLower(name))
|
value := r.URL.Query().Get(strings.ToLower(name))
|
||||||
|
|
Loading…
Reference in a new issue