mirror of
				https://github.com/binwiederhier/ntfy.git
				synced 2025-10-31 13:02:24 +01:00 
			
		
		
		
	Works
This commit is contained in:
		
							parent
							
								
									3eeeac2c13
								
							
						
					
					
						commit
						346d8d7967
					
				
					 8 changed files with 141 additions and 15 deletions
				
			
		|  | @ -585,9 +585,9 @@ func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error { | |||
| 	return writeMatrixDiscoveryResponse(w) | ||||
| } | ||||
| 
 | ||||
| func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*message, error) { | ||||
| 	t := fromContext[topic](r, contextTopic) | ||||
| 	vrate := fromContext[visitor](r, contextRateVisitor) | ||||
| func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*message, error) { | ||||
| 	t := fromContext[*topic](r, contextTopic) | ||||
| 	vrate := fromContext[*visitor](r, contextRateVisitor) | ||||
| 	body, err := util.Peek(r.Body, s.config.MessageLimit) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
|  | @ -670,7 +670,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes | |||
| } | ||||
| 
 | ||||
| func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visitor) error { | ||||
| 	m, err := s.handlePublishWithoutResponse(r, v) | ||||
| 	m, err := s.handlePublishInternal(r, v) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | @ -678,10 +678,14 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito | |||
| } | ||||
| 
 | ||||
| func (s *Server) handlePublishMatrix(w http.ResponseWriter, r *http.Request, v *visitor) error { | ||||
| 	_, err := s.handlePublishWithoutResponse(r, v) | ||||
| 	_, err := s.handlePublishInternal(r, v) | ||||
| 	if err != nil { | ||||
| 		if e, ok := err.(*errHTTP); ok && e.HTTPCode == errHTTPInsufficientStorageUnifiedPush.HTTPCode { | ||||
| 			return writeMatrixResponse(w, e.rejectedPushKey) | ||||
| 			topic := fromContext[*topic](r, contextTopic) | ||||
| 			pushKey := fromContext[string](r, contextMatrixPushKey) | ||||
| 			if time.Since(topic.LastAccess()) > matrixRejectPushKeyForUnifiedPushTopicWithoutRateVisitorAfter { | ||||
| 				return writeMatrixResponse(w, pushKey) | ||||
| 			} | ||||
| 		} | ||||
| 		return err | ||||
| 	} | ||||
|  | @ -1011,6 +1015,9 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v * | |||
| 	w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests | ||||
| 	w.Header().Set("Content-Type", contentType+"; charset=utf-8")                    // Android/Volley client needs charset! | ||||
| 	if poll { | ||||
| 		for _, t := range topics { | ||||
| 			t.Keepalive() | ||||
| 		} | ||||
| 		return s.sendOldMessages(topics, since, scheduled, v, sub) | ||||
| 	} | ||||
| 	ctx, cancel := context.WithCancel(context.Background()) | ||||
|  | @ -1037,7 +1044,12 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v * | |||
| 		case <-r.Context().Done(): | ||||
| 			return nil | ||||
| 		case <-time.After(s.config.KeepaliveInterval): | ||||
| 			logvr(v, r).Tag(tagSubscribe).Trace("Sending keepalive message") | ||||
| 			ev := logvr(v, r).Tag(tagSubscribe) | ||||
| 			if len(topics) == 1 { | ||||
| 				ev.With(topics[0]).Trace("Sending keepalive message to %s", topics[0].ID) | ||||
| 			} else { | ||||
| 				ev.Trace("Sending keepalive message to %d topics", len(topics)) | ||||
| 			} | ||||
| 			v.Keepalive() | ||||
| 			for _, t := range topics { | ||||
| 				t.Keepalive() | ||||
|  | @ -1154,6 +1166,9 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi | |||
| 	} | ||||
| 	w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests | ||||
| 	if poll { | ||||
| 		for _, t := range topics { | ||||
| 			t.Keepalive() | ||||
| 		} | ||||
| 		return s.sendOldMessages(topics, since, scheduled, v, sub) | ||||
| 	} | ||||
| 	subscriberIDs := make([]int, 0) | ||||
|  |  | |||
|  | @ -2,8 +2,8 @@ package server | |||
| 
 | ||||
| import ( | ||||
| 	"heckel.io/ntfy/log" | ||||
| 	"heckel.io/ntfy/util" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| func (s *Server) execManager() { | ||||
|  | @ -39,13 +39,13 @@ func (s *Server) execManager() { | |||
| 				ev := log.Tag(tagManager).With(t) | ||||
| 				if t.Stale() { | ||||
| 					if ev.IsTrace() { | ||||
| 						ev.Trace("- topic %s: Deleting stale topic (%d subscribers, accessed %s)", t.ID, subs, lastAccess.Format(time.RFC822)) | ||||
| 						ev.Trace("- topic %s: Deleting stale topic (%d subscribers, accessed %s)", t.ID, subs, util.FormatTime(lastAccess)) | ||||
| 					} | ||||
| 					emptyTopics++ | ||||
| 					delete(s.topics, t.ID) | ||||
| 				} else { | ||||
| 					if ev.IsTrace() { | ||||
| 						ev.Trace("- topic %s: %d subscribers, accessed %s", t.ID, subs, lastAccess.Format(time.RFC822)) | ||||
| 						ev.Trace("- topic %s: %d subscribers, accessed %s", t.ID, subs, util.FormatTime(lastAccess)) | ||||
| 					} | ||||
| 					subscribers += subs | ||||
| 				} | ||||
|  |  | |||
|  | @ -8,6 +8,7 @@ import ( | |||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| // Matrix Push Gateway / UnifiedPush / ntfy integration: | ||||
|  | @ -71,6 +72,14 @@ type matrixResponse struct { | |||
| 	Rejected []string `json:"rejected"` | ||||
| } | ||||
| 
 | ||||
| const ( | ||||
| 	// matrixRejectPushKeyForUnifiedPushTopicWithoutRateVisitorAfter is the time after which a Matrix response | ||||
| 	// will return an HTTP 200 with the push key (i.e. "rejected":["<pushkey>"]}), if no rate visitor has been set on | ||||
| 	// the topic. Rejecting the push key will instruct the Matrix server to invalidate the pushkey and stop sending | ||||
| 	// messages to it. See https://spec.matrix.org/v1.6/push-gateway-api/ | ||||
| 	matrixRejectPushKeyForUnifiedPushTopicWithoutRateVisitorAfter = 12 * time.Hour | ||||
| ) | ||||
| 
 | ||||
| // errMatrixPushkeyRejected represents an error when handing Matrix gateway messages | ||||
| // | ||||
| // If the push key is set, the app server will remove it and will never send messages using the same | ||||
|  | @ -126,7 +135,9 @@ func newRequestFromMatrixJSON(r *http.Request, baseURL string, messageLimit int) | |||
| 	if r.Header.Get("X-Forwarded-For") != "" { | ||||
| 		newRequest.Header.Set("X-Forwarded-For", r.Header.Get("X-Forwarded-For")) | ||||
| 	} | ||||
| 	newRequest.Header.Set("X-Matrix-Pushkey", pushKey) | ||||
| 	newRequest = withContext(newRequest, map[contextKey]any{ | ||||
| 		contextMatrixPushKey: pushKey, | ||||
| 	}) | ||||
| 	return newRequest, nil | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -11,6 +11,7 @@ type contextKey int | |||
| const ( | ||||
| 	contextRateVisitor contextKey = iota + 2586 | ||||
| 	contextTopic | ||||
| 	contextMatrixPushKey | ||||
| ) | ||||
| 
 | ||||
| func (s *Server) limitRequests(next handleFunc) handleFunc { | ||||
|  |  | |||
|  | @ -1172,6 +1172,56 @@ func TestServer_PublishEmailNoMailer_Fail(t *testing.T) { | |||
| 	require.Equal(t, 400, response.Code) | ||||
| } | ||||
| 
 | ||||
| func TestServer_PublishAndExpungeTopicAfter16Hours(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	s := newTestServer(t, newTestConfig(t)) | ||||
| 
 | ||||
| 	subFn := func(v *visitor, msg *message) error { | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| 	// Publish and check last access | ||||
| 	response := request(t, s, "POST", "/mytopic", "test", map[string]string{ | ||||
| 		"Cache": "no", | ||||
| 	}) | ||||
| 	require.Equal(t, 200, response.Code) | ||||
| 	require.True(t, s.topics["mytopic"].lastAccess.Unix() >= time.Now().Unix()-2) | ||||
| 	require.True(t, s.topics["mytopic"].lastAccess.Unix() <= time.Now().Unix()+2) | ||||
| 
 | ||||
| 	// Topic won't get pruned | ||||
| 	s.execManager() | ||||
| 	require.NotNil(t, s.topics["mytopic"]) | ||||
| 
 | ||||
| 	// Fudge with last access, but subscribe, and see that it won't get pruned (because of subscriber) | ||||
| 	subID := s.topics["mytopic"].Subscribe(subFn, "", func() {}) | ||||
| 	s.topics["mytopic"].lastAccess = time.Now().Add(-17 * time.Hour) | ||||
| 	s.execManager() | ||||
| 	require.NotNil(t, s.topics["mytopic"]) | ||||
| 
 | ||||
| 	// It'll finally get pruned now that there are no subscribers and last access is 17 hours ago | ||||
| 	s.topics["mytopic"].Unsubscribe(subID) | ||||
| 	s.execManager() | ||||
| 	require.Nil(t, s.topics["mytopic"]) | ||||
| } | ||||
| 
 | ||||
| func TestServer_TopicKeepaliveOnPoll(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 	s := newTestServer(t, newTestConfig(t)) | ||||
| 
 | ||||
| 	// Create topic by polling once | ||||
| 	response := request(t, s, "GET", "/mytopic/json?poll=1", "", nil) | ||||
| 	require.Equal(t, 200, response.Code) | ||||
| 
 | ||||
| 	// Mess with last access time | ||||
| 	s.topics["mytopic"].lastAccess = time.Now().Add(-17 * time.Hour) | ||||
| 
 | ||||
| 	// Poll again and check keepalive time | ||||
| 	response = request(t, s, "GET", "/mytopic/json?poll=1", "", nil) | ||||
| 	require.Equal(t, 200, response.Code) | ||||
| 	require.True(t, s.topics["mytopic"].lastAccess.Unix() >= time.Now().Unix()-2) | ||||
| 	require.True(t, s.topics["mytopic"].lastAccess.Unix() <= time.Now().Unix()+2) | ||||
| } | ||||
| 
 | ||||
| func TestServer_UnifiedPushDiscovery(t *testing.T) { | ||||
| 	s := newTestServer(t, newTestConfig(t)) | ||||
| 	response := request(t, s, "GET", "/mytopic?up=1", "", nil) | ||||
|  | @ -1301,6 +1351,32 @@ func TestServer_MatrixGateway_Push_Failure_NoSubscriber(t *testing.T) { | |||
| 	require.Equal(t, 50701, toHTTPError(t, response.Body.String()).Code) | ||||
| } | ||||
| 
 | ||||
| func TestServer_MatrixGateway_Push_Failure_NoSubscriber_After13Hours(t *testing.T) { | ||||
| 	c := newTestConfig(t) | ||||
| 	c.VisitorSubscriberRateLimiting = true | ||||
| 	s := newTestServer(t, c) | ||||
| 	notification := `{"notification":{"devices":[{"pushkey":"http://127.0.0.1:12345/mytopic?up=1"}]}}` | ||||
| 
 | ||||
| 	// No success if no rate visitor set (this also creates the topic in memory | ||||
| 	response := request(t, s, "POST", "/_matrix/push/v1/notify", notification, nil) | ||||
| 	require.Equal(t, 507, response.Code) | ||||
| 	require.Equal(t, 50701, toHTTPError(t, response.Body.String()).Code) | ||||
| 	require.Nil(t, s.topics["mytopic"].rateVisitor) | ||||
| 
 | ||||
| 	// Fake: This topic has been around for 13 hours without a rate visitor | ||||
| 	s.topics["mytopic"].lastAccess = time.Now().Add(-13 * time.Hour) | ||||
| 
 | ||||
| 	// Same request should now return HTTP 200 with a rejected pushkey | ||||
| 	response = request(t, s, "POST", "/_matrix/push/v1/notify", notification, nil) | ||||
| 	require.Equal(t, 200, response.Code) | ||||
| 	require.Equal(t, `{"rejected":["http://127.0.0.1:12345/mytopic?up=1"]}`, strings.TrimSpace(response.Body.String())) | ||||
| 
 | ||||
| 	// Slightly unrelated: Test that topic is pruned after 16 hours | ||||
| 	s.topics["mytopic"].lastAccess = time.Now().Add(-17 * time.Hour) | ||||
| 	s.execManager() | ||||
| 	require.Nil(t, s.topics["mytopic"]) | ||||
| } | ||||
| 
 | ||||
| func TestServer_MatrixGateway_Push_Failure_InvalidPushkey(t *testing.T) { | ||||
| 	s := newTestServer(t, newTestConfig(t)) | ||||
| 	notification := `{"notification":{"devices":[{"pushkey":"http://wrong-base-url.com/mytopic?up=1"}]}}` | ||||
|  |  | |||
|  | @ -2,13 +2,18 @@ package server | |||
| 
 | ||||
| import ( | ||||
| 	"heckel.io/ntfy/log" | ||||
| 	"heckel.io/ntfy/util" | ||||
| 	"math/rand" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	topicExpiryDuration = 6 * time.Hour | ||||
| 	// topicExpungeAfter defines how long a topic is active before it is removed from memory. | ||||
| 	// | ||||
| 	// This must be larger than matrixRejectPushKeyForUnifiedPushTopicWithoutRateVisitorAfter to give | ||||
| 	// time for more requests to come in, so that we can send a {"rejected":["<pushkey>"]} response back. | ||||
| 	topicExpungeAfter = 16 * time.Hour | ||||
| ) | ||||
| 
 | ||||
| // topic represents a channel to which subscribers can subscribe, and publishers | ||||
|  | @ -59,7 +64,13 @@ func (t *topic) Stale() bool { | |||
| 	if t.rateVisitor != nil && !t.rateVisitor.Stale() { | ||||
| 		return false | ||||
| 	} | ||||
| 	return len(t.subscribers) == 0 && time.Since(t.lastAccess) > topicExpiryDuration | ||||
| 	return len(t.subscribers) == 0 && time.Since(t.lastAccess) > topicExpungeAfter | ||||
| } | ||||
| 
 | ||||
| func (t *topic) LastAccess() time.Time { | ||||
| 	t.mu.RLock() | ||||
| 	defer t.mu.RUnlock() | ||||
| 	return t.lastAccess | ||||
| } | ||||
| 
 | ||||
| func (t *topic) SetRateVisitor(v *visitor) { | ||||
|  | @ -148,6 +159,7 @@ func (t *topic) Context() log.Context { | |||
| 	fields := map[string]any{ | ||||
| 		"topic":             t.ID, | ||||
| 		"topic_subscribers": len(t.subscribers), | ||||
| 		"topic_last_access": util.FormatTime(t.lastAccess), | ||||
| 	} | ||||
| 	if t.rateVisitor != nil { | ||||
| 		for k, v := range t.rateVisitor.Context() { | ||||
|  |  | |||
|  | @ -4,6 +4,7 @@ import ( | |||
| 	"github.com/stretchr/testify/require" | ||||
| 	"sync/atomic" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| func TestTopic_CancelSubscribers(t *testing.T) { | ||||
|  | @ -28,3 +29,13 @@ func TestTopic_CancelSubscribers(t *testing.T) { | |||
| 	require.True(t, canceled1.Load()) | ||||
| 	require.False(t, canceled2.Load()) | ||||
| } | ||||
| 
 | ||||
| func TestTopic_Keepalive(t *testing.T) { | ||||
| 	t.Parallel() | ||||
| 
 | ||||
| 	to := newTopic("mytopic") | ||||
| 	to.lastAccess = time.Now().Add(-1 * time.Hour) | ||||
| 	to.Keepalive() | ||||
| 	require.True(t, to.LastAccess().Unix() >= time.Now().Unix()-2) | ||||
| 	require.True(t, to.LastAccess().Unix() <= time.Now().Unix()+2) | ||||
| } | ||||
|  |  | |||
|  | @ -107,8 +107,8 @@ func withContext(r *http.Request, ctx map[contextKey]any) *http.Request { | |||
| 	return r.WithContext(c) | ||||
| } | ||||
| 
 | ||||
| func fromContext[T any](r *http.Request, key contextKey) *T { | ||||
| 	t, ok := r.Context().Value(key).(*T) | ||||
| func fromContext[T any](r *http.Request, key contextKey) T { | ||||
| 	t, ok := r.Context().Value(key).(T) | ||||
| 	if !ok { | ||||
| 		panic(fmt.Sprintf("cannot find key %v in request context", key)) | ||||
| 	} | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 binwiederhier
						binwiederhier