1
0
Fork 0
mirror of https://github.com/binwiederhier/ntfy.git synced 2025-05-20 22:08:20 +02:00

Kill existing subscribers when topic is reserved

This commit is contained in:
binwiederhier 2023-01-23 14:05:41 -05:00
parent e82a2e518c
commit bce71cb196
5 changed files with 169 additions and 36 deletions
server

View file

@ -38,11 +38,13 @@ import (
TODO
--
- Reservation: Kill existing subscribers when topic is reserved (deadcade)
- Rate limiting: Sensitive endpoints (account/login/change-password/...)
- Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
- Reservation (UI): Ask for confirmation when removing reservation (deadcade)
- Reservation icons (UI)
- reservation table delete button: dialog "keep or delete messages?"
- UI: Flickering upgrade banner when logging in
- JS constants
races:
- v.user --> see publishSyncEventAsync() test
@ -63,11 +65,6 @@ Limits & rate limiting:
Make sure account endpoints make sense for admins
UI:
-
- reservation table delete button: dialog "keep or delete messages?"
- flicker of upgrade banner
- JS constants
Sync:
- sync problems with "deleteAfter=0" and "displayName="
@ -359,7 +356,7 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
log.Info("%s Connection closed with HTTP %d (ntfy error %d): %s", logHTTPPrefix(v, r), httpErr.HTTPCode, httpErr.Code, err.Error())
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
w.WriteHeader(httpErr.HTTPCode)
io.WriteString(w, httpErr.JSON()+"\n")
}
@ -461,7 +458,7 @@ func (s *Server) handleTopic(w http.ResponseWriter, r *http.Request, v *visitor)
unifiedpush := readBoolParam(r, false, "x-unifiedpush", "unifiedpush", "up") // see PUT/POST too!
if unifiedpush {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
_, err := io.WriteString(w, `{"unifiedpush":{"version":1}}`+"\n")
return err
}
@ -538,7 +535,7 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor)
}
}
w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
if r.Method == http.MethodGet {
f, err := os.Open(file)
if err != nil {
@ -969,14 +966,16 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
}
return nil
}
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
if poll {
return s.sendOldMessages(topics, since, scheduled, v, sub)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
subscriberIDs := make([]int, 0)
for _, t := range topics {
subscriberIDs = append(subscriberIDs, t.Subscribe(sub))
subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v.MaybeUserID(), cancel))
}
defer func() {
for i, subscriberID := range subscriberIDs {
@ -991,6 +990,8 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
}
for {
select {
case <-ctx.Done():
return nil
case <-r.Context().Done():
return nil
case <-time.After(s.config.KeepaliveInterval):
@ -1033,8 +1034,20 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
return err
}
defer conn.Close()
// Subscription connections can be canceled externally, see topic.CancelSubscribers
subscriberContext, cancel := context.WithCancel(context.Background())
defer cancel()
// Use errgroup to run WebSocket reader and writer in Go routines
var wlock sync.Mutex
g, ctx := errgroup.WithContext(context.Background())
g, gctx := errgroup.WithContext(context.Background())
g.Go(func() error {
<-subscriberContext.Done()
log.Trace("%s Cancel received, closing subscriber connection", logHTTPPrefix(v, r))
conn.Close()
return &websocket.CloseError{Code: websocket.CloseNormalClosure, Text: "subscription was canceled"}
})
g.Go(func() error {
pongWait := s.config.KeepaliveInterval + wsPongWait
conn.SetReadLimit(wsReadLimit)
@ -1050,6 +1063,11 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
if err != nil {
return err
}
select {
case <-gctx.Done():
return nil
default:
}
}
})
g.Go(func() error {
@ -1064,7 +1082,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
}
for {
select {
case <-ctx.Done():
case <-gctx.Done():
return nil
case <-time.After(s.config.KeepaliveInterval):
v.Keepalive()
@ -1085,13 +1103,13 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
}
return conn.WriteJSON(msg)
}
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
if poll {
return s.sendOldMessages(topics, since, scheduled, v, sub)
}
subscriberIDs := make([]int, 0)
for _, t := range topics {
subscriberIDs = append(subscriberIDs, t.Subscribe(sub))
subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v.MaybeUserID(), cancel))
}
defer func() {
for i, subscriberID := range subscriberIDs {
@ -1193,11 +1211,7 @@ func (s *Server) topicFromPath(path string) (*topic, error) {
if len(parts) < 2 {
return nil, errHTTPBadRequestTopicInvalid
}
topics, err := s.topicsFromIDs(parts[1])
if err != nil {
return nil, err
}
return topics[0], nil
return s.topicFromID(parts[1])
}
func (s *Server) topicsFromPath(path string) ([]*topic, string, error) {
@ -1232,6 +1246,14 @@ func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) {
return topics, nil
}
func (s *Server) topicFromID(id string) (*topic, error) {
topics, err := s.topicsFromIDs(id)
if err != nil {
return nil, err
}
return topics[0], nil
}
func (s *Server) execManager() {
log.Debug("Manager: Starting")
defer log.Debug("Manager: Finished")