From 346d8d79671ca99f45e56e48e949eaed777b1d22 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Fri, 3 Mar 2023 22:22:07 -0500 Subject: [PATCH] Works --- server/server.go | 29 ++++++++++---- server/server_manager.go | 6 +-- server/server_matrix.go | 13 ++++++- server/server_middleware.go | 1 + server/server_test.go | 76 +++++++++++++++++++++++++++++++++++++ server/topic.go | 16 +++++++- server/topic_test.go | 11 ++++++ server/util.go | 4 +- 8 files changed, 141 insertions(+), 15 deletions(-) diff --git a/server/server.go b/server/server.go index aeaae04f..2397ba36 100644 --- a/server/server.go +++ b/server/server.go @@ -585,9 +585,9 @@ func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error { return writeMatrixDiscoveryResponse(w) } -func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*message, error) { - t := fromContext[topic](r, contextTopic) - vrate := fromContext[visitor](r, contextRateVisitor) +func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*message, error) { + t := fromContext[*topic](r, contextTopic) + vrate := fromContext[*visitor](r, contextRateVisitor) body, err := util.Peek(r.Body, s.config.MessageLimit) if err != nil { return nil, err @@ -670,7 +670,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes } func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visitor) error { - m, err := s.handlePublishWithoutResponse(r, v) + m, err := s.handlePublishInternal(r, v) if err != nil { return err } @@ -678,10 +678,14 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito } func (s *Server) handlePublishMatrix(w http.ResponseWriter, r *http.Request, v *visitor) error { - _, err := s.handlePublishWithoutResponse(r, v) + _, err := s.handlePublishInternal(r, v) if err != nil { if e, ok := err.(*errHTTP); ok && e.HTTPCode == errHTTPInsufficientStorageUnifiedPush.HTTPCode { - return writeMatrixResponse(w, e.rejectedPushKey) + topic := fromContext[*topic](r, contextTopic) + pushKey := fromContext[string](r, contextMatrixPushKey) + if time.Since(topic.LastAccess()) > matrixRejectPushKeyForUnifiedPushTopicWithoutRateVisitorAfter { + return writeMatrixResponse(w, pushKey) + } } return err } @@ -1011,6 +1015,9 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v * w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset! if poll { + for _, t := range topics { + t.Keepalive() + } return s.sendOldMessages(topics, since, scheduled, v, sub) } ctx, cancel := context.WithCancel(context.Background()) @@ -1037,7 +1044,12 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v * case <-r.Context().Done(): return nil case <-time.After(s.config.KeepaliveInterval): - logvr(v, r).Tag(tagSubscribe).Trace("Sending keepalive message") + ev := logvr(v, r).Tag(tagSubscribe) + if len(topics) == 1 { + ev.With(topics[0]).Trace("Sending keepalive message to %s", topics[0].ID) + } else { + ev.Trace("Sending keepalive message to %d topics", len(topics)) + } v.Keepalive() for _, t := range topics { t.Keepalive() @@ -1154,6 +1166,9 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi } w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests if poll { + for _, t := range topics { + t.Keepalive() + } return s.sendOldMessages(topics, since, scheduled, v, sub) } subscriberIDs := make([]int, 0) diff --git a/server/server_manager.go b/server/server_manager.go index 7deab25f..35f2c1b0 100644 --- a/server/server_manager.go +++ b/server/server_manager.go @@ -2,8 +2,8 @@ package server import ( "heckel.io/ntfy/log" + "heckel.io/ntfy/util" "strings" - "time" ) func (s *Server) execManager() { @@ -39,13 +39,13 @@ func (s *Server) execManager() { ev := log.Tag(tagManager).With(t) if t.Stale() { if ev.IsTrace() { - ev.Trace("- topic %s: Deleting stale topic (%d subscribers, accessed %s)", t.ID, subs, lastAccess.Format(time.RFC822)) + ev.Trace("- topic %s: Deleting stale topic (%d subscribers, accessed %s)", t.ID, subs, util.FormatTime(lastAccess)) } emptyTopics++ delete(s.topics, t.ID) } else { if ev.IsTrace() { - ev.Trace("- topic %s: %d subscribers, accessed %s", t.ID, subs, lastAccess.Format(time.RFC822)) + ev.Trace("- topic %s: %d subscribers, accessed %s", t.ID, subs, util.FormatTime(lastAccess)) } subscribers += subs } diff --git a/server/server_matrix.go b/server/server_matrix.go index bd96f43c..704c624b 100644 --- a/server/server_matrix.go +++ b/server/server_matrix.go @@ -8,6 +8,7 @@ import ( "io" "net/http" "strings" + "time" ) // Matrix Push Gateway / UnifiedPush / ntfy integration: @@ -71,6 +72,14 @@ type matrixResponse struct { Rejected []string `json:"rejected"` } +const ( + // matrixRejectPushKeyForUnifiedPushTopicWithoutRateVisitorAfter is the time after which a Matrix response + // will return an HTTP 200 with the push key (i.e. "rejected":[""]}), if no rate visitor has been set on + // the topic. Rejecting the push key will instruct the Matrix server to invalidate the pushkey and stop sending + // messages to it. See https://spec.matrix.org/v1.6/push-gateway-api/ + matrixRejectPushKeyForUnifiedPushTopicWithoutRateVisitorAfter = 12 * time.Hour +) + // errMatrixPushkeyRejected represents an error when handing Matrix gateway messages // // If the push key is set, the app server will remove it and will never send messages using the same @@ -126,7 +135,9 @@ func newRequestFromMatrixJSON(r *http.Request, baseURL string, messageLimit int) if r.Header.Get("X-Forwarded-For") != "" { newRequest.Header.Set("X-Forwarded-For", r.Header.Get("X-Forwarded-For")) } - newRequest.Header.Set("X-Matrix-Pushkey", pushKey) + newRequest = withContext(newRequest, map[contextKey]any{ + contextMatrixPushKey: pushKey, + }) return newRequest, nil } diff --git a/server/server_middleware.go b/server/server_middleware.go index 750a02e0..5c83cf70 100644 --- a/server/server_middleware.go +++ b/server/server_middleware.go @@ -11,6 +11,7 @@ type contextKey int const ( contextRateVisitor contextKey = iota + 2586 contextTopic + contextMatrixPushKey ) func (s *Server) limitRequests(next handleFunc) handleFunc { diff --git a/server/server_test.go b/server/server_test.go index 707c7d88..2ca4f983 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1172,6 +1172,56 @@ func TestServer_PublishEmailNoMailer_Fail(t *testing.T) { require.Equal(t, 400, response.Code) } +func TestServer_PublishAndExpungeTopicAfter16Hours(t *testing.T) { + t.Parallel() + s := newTestServer(t, newTestConfig(t)) + + subFn := func(v *visitor, msg *message) error { + return nil + } + + // Publish and check last access + response := request(t, s, "POST", "/mytopic", "test", map[string]string{ + "Cache": "no", + }) + require.Equal(t, 200, response.Code) + require.True(t, s.topics["mytopic"].lastAccess.Unix() >= time.Now().Unix()-2) + require.True(t, s.topics["mytopic"].lastAccess.Unix() <= time.Now().Unix()+2) + + // Topic won't get pruned + s.execManager() + require.NotNil(t, s.topics["mytopic"]) + + // Fudge with last access, but subscribe, and see that it won't get pruned (because of subscriber) + subID := s.topics["mytopic"].Subscribe(subFn, "", func() {}) + s.topics["mytopic"].lastAccess = time.Now().Add(-17 * time.Hour) + s.execManager() + require.NotNil(t, s.topics["mytopic"]) + + // It'll finally get pruned now that there are no subscribers and last access is 17 hours ago + s.topics["mytopic"].Unsubscribe(subID) + s.execManager() + require.Nil(t, s.topics["mytopic"]) +} + +func TestServer_TopicKeepaliveOnPoll(t *testing.T) { + t.Parallel() + s := newTestServer(t, newTestConfig(t)) + + // Create topic by polling once + response := request(t, s, "GET", "/mytopic/json?poll=1", "", nil) + require.Equal(t, 200, response.Code) + + // Mess with last access time + s.topics["mytopic"].lastAccess = time.Now().Add(-17 * time.Hour) + + // Poll again and check keepalive time + response = request(t, s, "GET", "/mytopic/json?poll=1", "", nil) + require.Equal(t, 200, response.Code) + require.True(t, s.topics["mytopic"].lastAccess.Unix() >= time.Now().Unix()-2) + require.True(t, s.topics["mytopic"].lastAccess.Unix() <= time.Now().Unix()+2) +} + func TestServer_UnifiedPushDiscovery(t *testing.T) { s := newTestServer(t, newTestConfig(t)) response := request(t, s, "GET", "/mytopic?up=1", "", nil) @@ -1301,6 +1351,32 @@ func TestServer_MatrixGateway_Push_Failure_NoSubscriber(t *testing.T) { require.Equal(t, 50701, toHTTPError(t, response.Body.String()).Code) } +func TestServer_MatrixGateway_Push_Failure_NoSubscriber_After13Hours(t *testing.T) { + c := newTestConfig(t) + c.VisitorSubscriberRateLimiting = true + s := newTestServer(t, c) + notification := `{"notification":{"devices":[{"pushkey":"http://127.0.0.1:12345/mytopic?up=1"}]}}` + + // No success if no rate visitor set (this also creates the topic in memory + response := request(t, s, "POST", "/_matrix/push/v1/notify", notification, nil) + require.Equal(t, 507, response.Code) + require.Equal(t, 50701, toHTTPError(t, response.Body.String()).Code) + require.Nil(t, s.topics["mytopic"].rateVisitor) + + // Fake: This topic has been around for 13 hours without a rate visitor + s.topics["mytopic"].lastAccess = time.Now().Add(-13 * time.Hour) + + // Same request should now return HTTP 200 with a rejected pushkey + response = request(t, s, "POST", "/_matrix/push/v1/notify", notification, nil) + require.Equal(t, 200, response.Code) + require.Equal(t, `{"rejected":["http://127.0.0.1:12345/mytopic?up=1"]}`, strings.TrimSpace(response.Body.String())) + + // Slightly unrelated: Test that topic is pruned after 16 hours + s.topics["mytopic"].lastAccess = time.Now().Add(-17 * time.Hour) + s.execManager() + require.Nil(t, s.topics["mytopic"]) +} + func TestServer_MatrixGateway_Push_Failure_InvalidPushkey(t *testing.T) { s := newTestServer(t, newTestConfig(t)) notification := `{"notification":{"devices":[{"pushkey":"http://wrong-base-url.com/mytopic?up=1"}]}}` diff --git a/server/topic.go b/server/topic.go index 23613683..f743a9dd 100644 --- a/server/topic.go +++ b/server/topic.go @@ -2,13 +2,18 @@ package server import ( "heckel.io/ntfy/log" + "heckel.io/ntfy/util" "math/rand" "sync" "time" ) const ( - topicExpiryDuration = 6 * time.Hour + // topicExpungeAfter defines how long a topic is active before it is removed from memory. + // + // This must be larger than matrixRejectPushKeyForUnifiedPushTopicWithoutRateVisitorAfter to give + // time for more requests to come in, so that we can send a {"rejected":[""]} response back. + topicExpungeAfter = 16 * time.Hour ) // topic represents a channel to which subscribers can subscribe, and publishers @@ -59,7 +64,13 @@ func (t *topic) Stale() bool { if t.rateVisitor != nil && !t.rateVisitor.Stale() { return false } - return len(t.subscribers) == 0 && time.Since(t.lastAccess) > topicExpiryDuration + return len(t.subscribers) == 0 && time.Since(t.lastAccess) > topicExpungeAfter +} + +func (t *topic) LastAccess() time.Time { + t.mu.RLock() + defer t.mu.RUnlock() + return t.lastAccess } func (t *topic) SetRateVisitor(v *visitor) { @@ -148,6 +159,7 @@ func (t *topic) Context() log.Context { fields := map[string]any{ "topic": t.ID, "topic_subscribers": len(t.subscribers), + "topic_last_access": util.FormatTime(t.lastAccess), } if t.rateVisitor != nil { for k, v := range t.rateVisitor.Context() { diff --git a/server/topic_test.go b/server/topic_test.go index cab2918a..b22bad55 100644 --- a/server/topic_test.go +++ b/server/topic_test.go @@ -4,6 +4,7 @@ import ( "github.com/stretchr/testify/require" "sync/atomic" "testing" + "time" ) func TestTopic_CancelSubscribers(t *testing.T) { @@ -28,3 +29,13 @@ func TestTopic_CancelSubscribers(t *testing.T) { require.True(t, canceled1.Load()) require.False(t, canceled2.Load()) } + +func TestTopic_Keepalive(t *testing.T) { + t.Parallel() + + to := newTopic("mytopic") + to.lastAccess = time.Now().Add(-1 * time.Hour) + to.Keepalive() + require.True(t, to.LastAccess().Unix() >= time.Now().Unix()-2) + require.True(t, to.LastAccess().Unix() <= time.Now().Unix()+2) +} diff --git a/server/util.go b/server/util.go index 75810b59..c719118b 100644 --- a/server/util.go +++ b/server/util.go @@ -107,8 +107,8 @@ func withContext(r *http.Request, ctx map[contextKey]any) *http.Request { return r.WithContext(c) } -func fromContext[T any](r *http.Request, key contextKey) *T { - t, ok := r.Context().Value(key).(*T) +func fromContext[T any](r *http.Request, key contextKey) T { + t, ok := r.Context().Value(key).(T) if !ok { panic(fmt.Sprintf("cannot find key %v in request context", key)) }