mirror of
				https://github.com/binwiederhier/ntfy.git
				synced 2025-10-31 13:02:24 +01:00 
			
		
		
		
	Add test, fails
This commit is contained in:
		
							parent
							
								
									4ab450309f
								
							
						
					
					
						commit
						29340e7e24
					
				
					 5 changed files with 89 additions and 45 deletions
				
			
		|  | @ -9,6 +9,12 @@ import ( | |||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/emersion/go-smtp" | ||||
| 	"github.com/gorilla/websocket" | ||||
| 	"golang.org/x/sync/errgroup" | ||||
| 	"heckel.io/ntfy/log" | ||||
| 	"heckel.io/ntfy/user" | ||||
| 	"heckel.io/ntfy/util" | ||||
| 	"io" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
|  | @ -24,13 +30,6 @@ import ( | |||
| 	"sync" | ||||
| 	"time" | ||||
| 	"unicode/utf8" | ||||
| 
 | ||||
| 	"github.com/emersion/go-smtp" | ||||
| 	"github.com/gorilla/websocket" | ||||
| 	"golang.org/x/sync/errgroup" | ||||
| 	"heckel.io/ntfy/log" | ||||
| 	"heckel.io/ntfy/user" | ||||
| 	"heckel.io/ntfy/util" | ||||
| ) | ||||
| 
 | ||||
| // Server is the main server, providing the UI and API for ntfy | ||||
|  | @ -105,15 +104,15 @@ var ( | |||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	firebaseControlTopic        = "~control"                // See Android if changed | ||||
| 	firebasePollTopic           = "~poll"                   // See iOS if changed | ||||
| 	emptyMessageBody            = "triggered"               // Used if message body is empty | ||||
| 	newMessageBody              = "New message"             // Used in poll requests as generic message | ||||
| 	defaultAttachmentMessage    = "You received a file: %s" // Used if message body is empty, and there is an attachment | ||||
| 	encodingBase64              = "base64"                  // Used mainly for binary UnifiedPush messages | ||||
| 	jsonBodyBytesLimit          = 16384 | ||||
| 	subscriberBilledTopicPrefix = "up_" | ||||
| 	subscriberBilledValidity    = 12 * time.Hour | ||||
| 	firebaseControlTopic      = "~control"                // See Android if changed | ||||
| 	firebasePollTopic         = "~poll"                   // See iOS if changed | ||||
| 	emptyMessageBody          = "triggered"               // Used if message body is empty | ||||
| 	newMessageBody            = "New message"             // Used in poll requests as generic message | ||||
| 	defaultAttachmentMessage  = "You received a file: %s" // Used if message body is empty, and there is an attachment | ||||
| 	encodingBase64            = "base64"                  // Used mainly for binary UnifiedPush messages | ||||
| 	jsonBodyBytesLimit        = 16384 | ||||
| 	unifiedPushTopicPrefix    = "up" // Temporarily, we rate limit all "up*" topics based on the subscriber | ||||
| 	rateVisitorExpiryDuration = 12 * time.Hour | ||||
| ) | ||||
| 
 | ||||
| // WebSocket constants | ||||
|  | @ -996,7 +995,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v * | |||
| 	defer cancel() | ||||
| 	subscriberIDs := make([]int, 0) | ||||
| 	for _, t := range topics { | ||||
| 		subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, subscriberBilledTopicPrefix) // temporarily do prefix as well | ||||
| 		subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) // temporarily do prefix as well | ||||
| 		subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel, subscriberRateLimited)) | ||||
| 	} | ||||
| 	defer func() { | ||||
|  | @ -1129,7 +1128,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi | |||
| 	} | ||||
| 	subscriberIDs := make([]int, 0) | ||||
| 	for _, t := range topics { | ||||
| 		subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, subscriberBilledTopicPrefix) // temporarily do prefix as well | ||||
| 		subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) // temporarily do prefix as well | ||||
| 		subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel, subscriberRateLimited)) | ||||
| 	} | ||||
| 	defer func() { | ||||
|  | @ -1162,7 +1161,6 @@ func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, schedu | |||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	subscriberTopics = readCommaSeperatedParam(r, "subscriber-rate-limit-topics", "x-subscriber-rate-limit-topics", "srlt") | ||||
| 	return | ||||
| } | ||||
|  |  | |||
|  | @ -40,7 +40,7 @@ func (s *Server) execManager() { | |||
| 				if ev.IsTrace() { | ||||
| 					expiryMessage := "" | ||||
| 					if subs == 0 { | ||||
| 						expiryTime := time.Until(t.vRateExpires) | ||||
| 						expiryTime := time.Until(t.rateVisitorExpires) | ||||
| 						expiryMessage = ", expires in " + expiryTime.String() | ||||
| 					} | ||||
| 					ev.Trace("- topic %s: %d subscribers%s", t.ID, subs, expiryMessage) | ||||
|  |  | |||
|  | @ -25,15 +25,15 @@ func (s *Server) limitRequestsWithTopic(next handleFunc) handleFunc { | |||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		vRate := v | ||||
| 		vrate := v | ||||
| 		if topicCountsAgainst := t.Billee(); topicCountsAgainst != nil { | ||||
| 			vRate = topicCountsAgainst | ||||
| 			vrate = topicCountsAgainst | ||||
| 		} | ||||
| 		r = r.WithContext(context.WithValue(context.WithValue(r.Context(), "vRate", vRate), "topic", t)) | ||||
| 		r = r.WithContext(context.WithValue(context.WithValue(r.Context(), "vRate", vrate), "topic", t)) | ||||
| 
 | ||||
| 		if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) { | ||||
| 			return next(w, r, v) | ||||
| 		} else if !vRate.RequestAllowed() { | ||||
| 		} else if !vrate.RequestAllowed() { | ||||
| 			return errHTTPTooManyRequestsLimitRequests | ||||
| 		} | ||||
| 		return next(w, r, v) | ||||
|  |  | |||
|  | @ -1889,6 +1889,49 @@ func TestServer_AnonymousUser_And_NonTierUser_Are_Same_Visitor(t *testing.T) { | |||
| 	require.Equal(t, int64(2), account.Stats.Messages) | ||||
| } | ||||
| 
 | ||||
| func TestServer_SubscriberRateLimiting(t *testing.T) { | ||||
| 	c := newTestConfigWithAuthFile(t) | ||||
| 	c.VisitorRequestLimitBurst = 3 | ||||
| 	s := newTestServer(t, c) | ||||
| 
 | ||||
| 	subscriber1Fn := func(r *http.Request) { | ||||
| 		r.RemoteAddr = "1.2.3.4" | ||||
| 	} | ||||
| 	rr := request(t, s, "GET", "/subscriber1topic/json?poll=1", "", map[string]string{ | ||||
| 		"Subscriber-Rate-Limit-Topics": "mytopic1", | ||||
| 	}, subscriber1Fn) | ||||
| 	require.Equal(t, 200, rr.Code) | ||||
| 	require.Equal(t, "", rr.Body.String()) | ||||
| 
 | ||||
| 	subscriber2Fn := func(r *http.Request) { | ||||
| 		r.RemoteAddr = "8.7.7.1" | ||||
| 	} | ||||
| 	rr = request(t, s, "GET", "/upSUB2topic/json?poll=1", "", nil, subscriber2Fn) | ||||
| 	require.Equal(t, 200, rr.Code) | ||||
| 	require.Equal(t, "", rr.Body.String()) | ||||
| 
 | ||||
| 	for i := 0; i < 3; i++ { | ||||
| 		rr := request(t, s, "PUT", "/subscriber1topic", "some message", nil) | ||||
| 		require.Equal(t, 200, rr.Code) | ||||
| 	} | ||||
| 	rr = request(t, s, "PUT", "/subscriber1topic", "some message", nil) | ||||
| 	require.Equal(t, 429, rr.Code) | ||||
| 
 | ||||
| 	for i := 0; i < 3; i++ { | ||||
| 		rr := request(t, s, "PUT", "/upSUB2topic", "some message", nil) | ||||
| 		require.Equal(t, 200, rr.Code) // If we fail here, handlePublish is using the wrong visitor! | ||||
| 	} | ||||
| 	rr = request(t, s, "PUT", "/upSUB2topic", "some message", nil) | ||||
| 	require.Equal(t, 429, rr.Code) | ||||
| 
 | ||||
| 	for i := 0; i < 3; i++ { | ||||
| 		rr := request(t, s, "PUT", "/some-other-topic", "some message", nil) | ||||
| 		require.Equal(t, 200, rr.Code) | ||||
| 	} | ||||
| 	rr = request(t, s, "PUT", "/some-other-topic", "some message", nil) | ||||
| 	require.Equal(t, 429, rr.Code) | ||||
| } | ||||
| 
 | ||||
| func newTestConfig(t *testing.T) *Config { | ||||
| 	conf := NewConfig() | ||||
| 	conf.BaseURL = "http://127.0.0.1:12345" | ||||
|  | @ -1914,7 +1957,7 @@ func newTestServer(t *testing.T, config *Config) *Server { | |||
| 	return server | ||||
| } | ||||
| 
 | ||||
| func request(t *testing.T, s *Server, method, url, body string, headers map[string]string) *httptest.ResponseRecorder { | ||||
| func request(t *testing.T, s *Server, method, url, body string, headers map[string]string, fn ...func(r *http.Request)) *httptest.ResponseRecorder { | ||||
| 	rr := httptest.NewRecorder() | ||||
| 	req, err := http.NewRequest(method, url, strings.NewReader(body)) | ||||
| 	if err != nil { | ||||
|  | @ -1924,6 +1967,9 @@ func request(t *testing.T, s *Server, method, url, body string, headers map[stri | |||
| 	for k, v := range headers { | ||||
| 		req.Header.Set(k, v) | ||||
| 	} | ||||
| 	for _, f := range fn { | ||||
| 		f(req) | ||||
| 	} | ||||
| 	s.handle(rr, req) | ||||
| 	return rr | ||||
| } | ||||
|  |  | |||
|  | @ -11,11 +11,11 @@ import ( | |||
| // topic represents a channel to which subscribers can subscribe, and publishers | ||||
| // can publish a message | ||||
| type topic struct { | ||||
| 	ID           string | ||||
| 	subscribers  map[int]*topicSubscriber | ||||
| 	vRate        *visitor | ||||
| 	vRateExpires time.Time | ||||
| 	mu           sync.Mutex | ||||
| 	ID                 string | ||||
| 	subscribers        map[int]*topicSubscriber | ||||
| 	rateVisitor        *visitor | ||||
| 	rateVisitorExpires time.Time | ||||
| 	mu                 sync.RWMutex | ||||
| } | ||||
| 
 | ||||
| type topicSubscriber struct { | ||||
|  | @ -49,9 +49,9 @@ func (t *topic) Subscribe(s subscriber, visitor *visitor, cancel func(), subscri | |||
| 	} | ||||
| 
 | ||||
| 	// if no subscriber is already handling the rate limit | ||||
| 	if t.vRate == nil && subscriberRateLimit { | ||||
| 		t.vRate = visitor | ||||
| 		t.vRateExpires = time.Time{} | ||||
| 	if t.rateVisitor == nil && subscriberRateLimit { | ||||
| 		t.rateVisitor = visitor | ||||
| 		t.rateVisitorExpires = time.Time{} | ||||
| 	} | ||||
| 
 | ||||
| 	return subscriberID | ||||
|  | @ -61,16 +61,16 @@ func (t *topic) Stale() bool { | |||
| 	t.mu.Lock() | ||||
| 	defer t.mu.Unlock() | ||||
| 	// if Time is initialized (not the zero value) and the expiry time has passed | ||||
| 	if !t.vRateExpires.IsZero() && t.vRateExpires.Before(time.Now()) { | ||||
| 		t.vRate = nil | ||||
| 	if !t.rateVisitorExpires.IsZero() && t.rateVisitorExpires.Before(time.Now()) { | ||||
| 		t.rateVisitor = nil | ||||
| 	} | ||||
| 	return len(t.subscribers) == 0 && t.vRate == nil | ||||
| 	return len(t.subscribers) == 0 && t.rateVisitor == nil | ||||
| } | ||||
| 
 | ||||
| func (t *topic) Billee() *visitor { | ||||
| 	t.mu.Lock() | ||||
| 	defer t.mu.Unlock() | ||||
| 	return t.vRate | ||||
| 	t.mu.RLock() | ||||
| 	defer t.mu.RUnlock() | ||||
| 	return t.rateVisitor | ||||
| } | ||||
| 
 | ||||
| // Unsubscribe removes the subscription from the list of subscribers | ||||
|  | @ -84,16 +84,16 @@ func (t *topic) Unsubscribe(id int) { | |||
| 	// look for an active subscriber (in random order) that wants to handle the rate limit | ||||
| 	for _, v := range t.subscribers { | ||||
| 		if v.subscriberRateLimit { | ||||
| 			t.vRate = v.visitor | ||||
| 			t.vRateExpires = time.Time{} | ||||
| 			t.rateVisitor = v.visitor | ||||
| 			t.rateVisitorExpires = time.Time{} | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// if no active subscriber is found, count it towards the leaving subscriber | ||||
| 	if deletingSub.subscriberRateLimit { | ||||
| 		t.vRate = deletingSub.visitor | ||||
| 		t.vRateExpires = time.Now().Add(subscriberBilledValidity) | ||||
| 		t.rateVisitor = deletingSub.visitor | ||||
| 		t.rateVisitorExpires = time.Now().Add(rateVisitorExpiryDuration) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  | @ -123,8 +123,8 @@ func (t *topic) Publish(v *visitor, m *message) error { | |||
| 
 | ||||
| // SubscribersCount returns the number of subscribers to this topic | ||||
| func (t *topic) SubscribersCount() int { | ||||
| 	t.mu.Lock() | ||||
| 	defer t.mu.Unlock() | ||||
| 	t.mu.RLock() | ||||
| 	defer t.mu.RUnlock() | ||||
| 	return len(t.subscribers) | ||||
| } | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 binwiederhier
						binwiederhier