mirror of
				https://github.com/binwiederhier/ntfy.git
				synced 2025-11-04 06:50:32 +01:00 
			
		
		
		
	Polishing
This commit is contained in:
		
							parent
							
								
									8eae44ea61
								
							
						
					
					
						commit
						2329695a47
					
				
					 5 changed files with 88 additions and 44 deletions
				
			
		| 
						 | 
				
			
			@ -570,14 +570,8 @@ func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*message, error) {
 | 
			
		||||
	vrate, ok := r.Context().Value(contextRateVisitor).(*visitor)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return nil, errHTTPInternalError
 | 
			
		||||
	}
 | 
			
		||||
	t, ok := r.Context().Value(contextTopic).(*topic)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return nil, errHTTPInternalError
 | 
			
		||||
	}
 | 
			
		||||
	t := fromContext[topic](r, contextTopic)
 | 
			
		||||
	vrate := fromContext[visitor](r, contextRateVisitor)
 | 
			
		||||
	if !vrate.MessageAllowed() {
 | 
			
		||||
		return nil, errHTTPTooManyRequestsLimitMessages
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -586,10 +580,13 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
 | 
			
		|||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	m := newDefaultMessage(t.ID, "")
 | 
			
		||||
	cache, firebase, email, unifiedpush, err := s.parsePublishParams(r, vrate, m)
 | 
			
		||||
	cache, firebase, email, unifiedpush, err := s.parsePublishParams(r, m)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	if email != "" && !vrate.EmailAllowed() {
 | 
			
		||||
		return nil, errHTTPTooManyRequestsLimitEmails
 | 
			
		||||
	}
 | 
			
		||||
	if m.PollID != "" {
 | 
			
		||||
		m = newPollRequestMessage(t.ID, m.PollID)
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -605,13 +602,15 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
 | 
			
		|||
		m.Message = emptyMessageBody
 | 
			
		||||
	}
 | 
			
		||||
	delayed := m.Time > time.Now().Unix()
 | 
			
		||||
	ev := logvrm(vrate, r, m).
 | 
			
		||||
	ev := logvrm(v, r, m).
 | 
			
		||||
		Tag(tagPublish).
 | 
			
		||||
		Fields(log.Context{
 | 
			
		||||
			"message_delayed":     delayed,
 | 
			
		||||
			"message_firebase":    firebase,
 | 
			
		||||
			"message_unifiedpush": unifiedpush,
 | 
			
		||||
			"message_email":       email,
 | 
			
		||||
			"rate_visitor_ip":     vrate.IP().String(),
 | 
			
		||||
			"rate_user_id":        vrate.MaybeUserID(),
 | 
			
		||||
		})
 | 
			
		||||
	if ev.IsTrace() {
 | 
			
		||||
		ev.Field("message_body", util.MaybeMarshalJSON(m)).Trace("Received message")
 | 
			
		||||
| 
						 | 
				
			
			@ -623,7 +622,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
 | 
			
		|||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		if s.firebaseClient != nil && firebase {
 | 
			
		||||
			go s.sendToFirebase(vrate, m)
 | 
			
		||||
			go s.sendToFirebase(v, m)
 | 
			
		||||
		}
 | 
			
		||||
		if s.smtpSender != nil && email != "" {
 | 
			
		||||
			go s.sendEmail(v, m, email)
 | 
			
		||||
| 
						 | 
				
			
			@ -708,7 +707,7 @@ func (s *Server) forwardPollRequest(v *visitor, m *message) {
 | 
			
		|||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) parsePublishParams(r *http.Request, vrate *visitor, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) {
 | 
			
		||||
func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) {
 | 
			
		||||
	cache = readBoolParam(r, true, "x-cache", "cache")
 | 
			
		||||
	firebase = readBoolParam(r, true, "x-firebase", "firebase")
 | 
			
		||||
	m.Title = readParam(r, "x-title", "title", "t")
 | 
			
		||||
| 
						 | 
				
			
			@ -747,11 +746,6 @@ func (s *Server) parsePublishParams(r *http.Request, vrate *visitor, m *message)
 | 
			
		|||
		m.Icon = icon
 | 
			
		||||
	}
 | 
			
		||||
	email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e")
 | 
			
		||||
	if email != "" {
 | 
			
		||||
		if !vrate.EmailAllowed() {
 | 
			
		||||
			return false, false, "", false, errHTTPTooManyRequestsLimitEmails
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if s.smtpSender == nil && email != "" {
 | 
			
		||||
		return false, false, "", false, errHTTPBadRequestEmailDisabled
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -993,7 +987,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
 | 
			
		|||
	defer cancel()
 | 
			
		||||
	subscriberIDs := make([]int, 0)
 | 
			
		||||
	for _, t := range topics {
 | 
			
		||||
		subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel))
 | 
			
		||||
		subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v.MaybeUserID(), cancel))
 | 
			
		||||
	}
 | 
			
		||||
	defer func() {
 | 
			
		||||
		for i, subscriberID := range subscriberIDs {
 | 
			
		||||
| 
						 | 
				
			
			@ -1126,7 +1120,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
 | 
			
		|||
	}
 | 
			
		||||
	subscriberIDs := make([]int, 0)
 | 
			
		||||
	for _, t := range topics {
 | 
			
		||||
		subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel))
 | 
			
		||||
		subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v.MaybeUserID(), cancel))
 | 
			
		||||
	}
 | 
			
		||||
	defer func() {
 | 
			
		||||
		for i, subscriberID := range subscriberIDs {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -3,7 +3,6 @@ package server
 | 
			
		|||
import (
 | 
			
		||||
	"heckel.io/ntfy/log"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (s *Server) execManager() {
 | 
			
		||||
| 
						 | 
				
			
			@ -38,16 +37,23 @@ func (s *Server) execManager() {
 | 
			
		|||
				subs := t.SubscribersCount()
 | 
			
		||||
				ev := log.Tag(tagManager)
 | 
			
		||||
				if ev.IsTrace() {
 | 
			
		||||
					expiryMessage := ""
 | 
			
		||||
					if subs == 0 {
 | 
			
		||||
						expiryTime := time.Until(t.expires)
 | 
			
		||||
						expiryMessage = ", expires in " + expiryTime.String()
 | 
			
		||||
					vrate := t.RateVisitor()
 | 
			
		||||
					if vrate != nil {
 | 
			
		||||
						ev.Fields(log.Context{
 | 
			
		||||
							"rate_visitor_ip":      vrate.IP(),
 | 
			
		||||
							"rate_visitor_user_id": vrate.MaybeUserID(),
 | 
			
		||||
						})
 | 
			
		||||
					}
 | 
			
		||||
					ev.Trace("- topic %s: %d subscribers%s", t.ID, subs, expiryMessage)
 | 
			
		||||
					ev.
 | 
			
		||||
						Fields(log.Context{
 | 
			
		||||
							"message_topic":             t.ID,
 | 
			
		||||
							"message_topic_subscribers": subs,
 | 
			
		||||
						}).
 | 
			
		||||
						Trace("- topic %s: %d subscribers", t.ID, subs)
 | 
			
		||||
				}
 | 
			
		||||
				msgs, exists := messageCounts[t.ID]
 | 
			
		||||
				if t.Stale() && (!exists || msgs == 0) {
 | 
			
		||||
					log.Tag(tagManager).Trace("Deleting empty topic %s", t.ID)
 | 
			
		||||
					log.Tag(tagManager).Field("message_topic", t.ID).Trace("Deleting empty topic %s", t.ID)
 | 
			
		||||
					emptyTopics++
 | 
			
		||||
					delete(s.topics, t.ID)
 | 
			
		||||
					continue
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -2030,7 +2030,40 @@ func TestServer_Matrix_SubscriberRateLimiting_UP_Only(t *testing.T) {
 | 
			
		|||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// FIXME add test for rate visitor expiration
 | 
			
		||||
func TestServer_SubscriberRateLimiting_VisitorExpiration(t *testing.T) {
 | 
			
		||||
	c := newTestConfig(t)
 | 
			
		||||
	c.VisitorRequestLimitBurst = 3
 | 
			
		||||
	s := newTestServer(t, c)
 | 
			
		||||
 | 
			
		||||
	// "Register" rate visitor
 | 
			
		||||
	subscriberFn := func(r *http.Request) {
 | 
			
		||||
		r.RemoteAddr = "1.2.3.4"
 | 
			
		||||
	}
 | 
			
		||||
	rr := request(t, s, "GET", "/mytopic/json?poll=1", "", map[string]string{
 | 
			
		||||
		"rate-topics": "*",
 | 
			
		||||
	}, subscriberFn)
 | 
			
		||||
	require.Equal(t, 200, rr.Code)
 | 
			
		||||
	require.Equal(t, "1.2.3.4", s.topics["mytopic"].rateVisitor.ip.String())
 | 
			
		||||
	require.Equal(t, s.visitors["ip:1.2.3.4"], s.topics["mytopic"].rateVisitor)
 | 
			
		||||
 | 
			
		||||
	// Publish message, observe rate visitor tokens being decreased
 | 
			
		||||
	response := request(t, s, "POST", "/mytopic", "some message", nil)
 | 
			
		||||
	require.Equal(t, 200, response.Code)
 | 
			
		||||
	require.Equal(t, int64(0), s.visitors["ip:9.9.9.9"].messagesLimiter.Value())
 | 
			
		||||
	require.Equal(t, int64(1), s.topics["mytopic"].rateVisitor.messagesLimiter.Value())
 | 
			
		||||
	require.Equal(t, s.visitors["ip:1.2.3.4"], s.topics["mytopic"].rateVisitor)
 | 
			
		||||
 | 
			
		||||
	// Expire visitor
 | 
			
		||||
	s.visitors["ip:1.2.3.4"].seen = time.Now().Add(-1 * 25 * time.Hour)
 | 
			
		||||
	s.pruneVisitors()
 | 
			
		||||
 | 
			
		||||
	// Publish message again, observe that rateVisitor is not used anymore and is reset
 | 
			
		||||
	response = request(t, s, "POST", "/mytopic", "some message", nil)
 | 
			
		||||
	require.Equal(t, 200, response.Code)
 | 
			
		||||
	require.Equal(t, int64(1), s.visitors["ip:9.9.9.9"].messagesLimiter.Value())
 | 
			
		||||
	require.Nil(t, s.topics["mytopic"].rateVisitor)
 | 
			
		||||
	require.Nil(t, s.visitors["ip:1.2.3.4"])
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newTestConfig(t *testing.T) *Config {
 | 
			
		||||
	conf := NewConfig()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -4,11 +4,6 @@ import (
 | 
			
		|||
	"heckel.io/ntfy/log"
 | 
			
		||||
	"math/rand"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	topicExpiryDuration = 6 * time.Hour
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// topic represents a channel to which subscribers can subscribe, and publishers
 | 
			
		||||
| 
						 | 
				
			
			@ -17,13 +12,12 @@ type topic struct {
 | 
			
		|||
	ID          string
 | 
			
		||||
	subscribers map[int]*topicSubscriber
 | 
			
		||||
	rateVisitor *visitor
 | 
			
		||||
	expires     time.Time
 | 
			
		||||
	mu          sync.RWMutex
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type topicSubscriber struct {
 | 
			
		||||
	userID     string // User ID associated with this subscription, may be empty
 | 
			
		||||
	subscriber subscriber
 | 
			
		||||
	visitor    *visitor // User ID associated with this subscription, may be empty
 | 
			
		||||
	cancel     func()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -39,12 +33,12 @@ func newTopic(id string) *topic {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
// Subscribe subscribes to this topic
 | 
			
		||||
func (t *topic) Subscribe(s subscriber, visitor *visitor, cancel func()) int {
 | 
			
		||||
func (t *topic) Subscribe(s subscriber, userID string, cancel func()) int {
 | 
			
		||||
	t.mu.Lock()
 | 
			
		||||
	defer t.mu.Unlock()
 | 
			
		||||
	subscriberID := rand.Int()
 | 
			
		||||
	t.subscribers[subscriberID] = &topicSubscriber{
 | 
			
		||||
		visitor:    visitor, // May be empty
 | 
			
		||||
		userID:     userID, // May be empty
 | 
			
		||||
		subscriber: s,
 | 
			
		||||
		cancel:     cancel,
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -54,7 +48,10 @@ func (t *topic) Subscribe(s subscriber, visitor *visitor, cancel func()) int {
 | 
			
		|||
func (t *topic) Stale() bool {
 | 
			
		||||
	t.mu.Lock()
 | 
			
		||||
	defer t.mu.Unlock()
 | 
			
		||||
	return len(t.subscribers) == 0 && t.expires.Before(time.Now())
 | 
			
		||||
	if t.rateVisitor != nil && !t.rateVisitor.Stale() {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	return len(t.subscribers) == 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *topic) SetRateVisitor(v *visitor) {
 | 
			
		||||
| 
						 | 
				
			
			@ -66,6 +63,9 @@ func (t *topic) SetRateVisitor(v *visitor) {
 | 
			
		|||
func (t *topic) RateVisitor() *visitor {
 | 
			
		||||
	t.mu.Lock()
 | 
			
		||||
	defer t.mu.Unlock()
 | 
			
		||||
	if t.rateVisitor != nil && t.rateVisitor.Stale() {
 | 
			
		||||
		t.rateVisitor = nil
 | 
			
		||||
	}
 | 
			
		||||
	return t.rateVisitor
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -74,9 +74,6 @@ func (t *topic) Unsubscribe(id int) {
 | 
			
		|||
	t.mu.Lock()
 | 
			
		||||
	defer t.mu.Unlock()
 | 
			
		||||
	delete(t.subscribers, id)
 | 
			
		||||
	if len(t.subscribers) == 0 {
 | 
			
		||||
		t.expires = time.Now().Add(topicExpiryDuration)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Publish asynchronously publishes to all subscribers
 | 
			
		||||
| 
						 | 
				
			
			@ -115,9 +112,14 @@ func (t *topic) CancelSubscribers(exceptUserID string) {
 | 
			
		|||
	t.mu.Lock()
 | 
			
		||||
	defer t.mu.Unlock()
 | 
			
		||||
	for _, s := range t.subscribers {
 | 
			
		||||
		if s.visitor.MaybeUserID() != exceptUserID {
 | 
			
		||||
			// TODO: Shouldn't this log the IP for anonymous visitors? It was s.userID before my change.
 | 
			
		||||
			log.Tag(tagSubscribe).Field("topic", t.ID).Debug("Canceling subscriber %s", s.visitor.MaybeUserID())
 | 
			
		||||
		if s.userID != exceptUserID {
 | 
			
		||||
			log.
 | 
			
		||||
				Tag(tagSubscribe).
 | 
			
		||||
				Fields(log.Context{
 | 
			
		||||
					"message_topic": t.ID,
 | 
			
		||||
					"user_id":       s.userID,
 | 
			
		||||
				}).
 | 
			
		||||
				Debug("Canceling subscriber %s", s.userID)
 | 
			
		||||
			s.cancel()
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -130,7 +132,7 @@ func (t *topic) subscribersCopy() map[int]*topicSubscriber {
 | 
			
		|||
	subscribers := make(map[int]*topicSubscriber)
 | 
			
		||||
	for k, sub := range t.subscribers {
 | 
			
		||||
		subscribers[k] = &topicSubscriber{
 | 
			
		||||
			visitor:    sub.visitor,
 | 
			
		||||
			userID:     sub.userID,
 | 
			
		||||
			subscriber: sub.subscriber,
 | 
			
		||||
			cancel:     sub.cancel,
 | 
			
		||||
		}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -2,6 +2,7 @@ package server
 | 
			
		|||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"heckel.io/ntfy/util"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
| 
						 | 
				
			
			@ -105,3 +106,11 @@ func withContext(r *http.Request, ctx map[contextKey]any) *http.Request {
 | 
			
		|||
	}
 | 
			
		||||
	return r.WithContext(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func fromContext[T any](r *http.Request, key contextKey) *T {
 | 
			
		||||
	t, ok := r.Context().Value(key).(*T)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		panic(fmt.Sprintf("cannot find key %v in request context", key))
 | 
			
		||||
	}
 | 
			
		||||
	return t
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue