mirror of
https://github.com/binwiederhier/ntfy.git
synced 2025-05-20 22:08:20 +02:00
Kill existing subscribers when topic is reserved
This commit is contained in:
parent
e82a2e518c
commit
bce71cb196
5 changed files with 169 additions and 36 deletions
server
|
@ -38,11 +38,13 @@ import (
|
|||
TODO
|
||||
--
|
||||
|
||||
- Reservation: Kill existing subscribers when topic is reserved (deadcade)
|
||||
- Rate limiting: Sensitive endpoints (account/login/change-password/...)
|
||||
- Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
|
||||
- Reservation (UI): Ask for confirmation when removing reservation (deadcade)
|
||||
- Reservation icons (UI)
|
||||
- reservation table delete button: dialog "keep or delete messages?"
|
||||
- UI: Flickering upgrade banner when logging in
|
||||
- JS constants
|
||||
|
||||
races:
|
||||
- v.user --> see publishSyncEventAsync() test
|
||||
|
@ -63,11 +65,6 @@ Limits & rate limiting:
|
|||
Make sure account endpoints make sense for admins
|
||||
|
||||
|
||||
UI:
|
||||
-
|
||||
- reservation table delete button: dialog "keep or delete messages?"
|
||||
- flicker of upgrade banner
|
||||
- JS constants
|
||||
Sync:
|
||||
- sync problems with "deleteAfter=0" and "displayName="
|
||||
|
||||
|
@ -359,7 +356,7 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
|
|||
log.Info("%s Connection closed with HTTP %d (ntfy error %d): %s", logHTTPPrefix(v, r), httpErr.HTTPCode, httpErr.Code, err.Error())
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
|
||||
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
||||
w.WriteHeader(httpErr.HTTPCode)
|
||||
io.WriteString(w, httpErr.JSON()+"\n")
|
||||
}
|
||||
|
@ -461,7 +458,7 @@ func (s *Server) handleTopic(w http.ResponseWriter, r *http.Request, v *visitor)
|
|||
unifiedpush := readBoolParam(r, false, "x-unifiedpush", "unifiedpush", "up") // see PUT/POST too!
|
||||
if unifiedpush {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
|
||||
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
||||
_, err := io.WriteString(w, `{"unifiedpush":{"version":1}}`+"\n")
|
||||
return err
|
||||
}
|
||||
|
@ -538,7 +535,7 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor)
|
|||
}
|
||||
}
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
|
||||
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
||||
if r.Method == http.MethodGet {
|
||||
f, err := os.Open(file)
|
||||
if err != nil {
|
||||
|
@ -969,14 +966,16 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
|||
}
|
||||
return nil
|
||||
}
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
|
||||
w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
|
||||
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
||||
w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
|
||||
if poll {
|
||||
return s.sendOldMessages(topics, since, scheduled, v, sub)
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
subscriberIDs := make([]int, 0)
|
||||
for _, t := range topics {
|
||||
subscriberIDs = append(subscriberIDs, t.Subscribe(sub))
|
||||
subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v.MaybeUserID(), cancel))
|
||||
}
|
||||
defer func() {
|
||||
for i, subscriberID := range subscriberIDs {
|
||||
|
@ -991,6 +990,8 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
|||
}
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-r.Context().Done():
|
||||
return nil
|
||||
case <-time.After(s.config.KeepaliveInterval):
|
||||
|
@ -1033,8 +1034,20 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
|||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Subscription connections can be canceled externally, see topic.CancelSubscribers
|
||||
subscriberContext, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Use errgroup to run WebSocket reader and writer in Go routines
|
||||
var wlock sync.Mutex
|
||||
g, ctx := errgroup.WithContext(context.Background())
|
||||
g, gctx := errgroup.WithContext(context.Background())
|
||||
g.Go(func() error {
|
||||
<-subscriberContext.Done()
|
||||
log.Trace("%s Cancel received, closing subscriber connection", logHTTPPrefix(v, r))
|
||||
conn.Close()
|
||||
return &websocket.CloseError{Code: websocket.CloseNormalClosure, Text: "subscription was canceled"}
|
||||
})
|
||||
g.Go(func() error {
|
||||
pongWait := s.config.KeepaliveInterval + wsPongWait
|
||||
conn.SetReadLimit(wsReadLimit)
|
||||
|
@ -1050,6 +1063,11 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
select {
|
||||
case <-gctx.Done():
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
}
|
||||
})
|
||||
g.Go(func() error {
|
||||
|
@ -1064,7 +1082,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
|||
}
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-gctx.Done():
|
||||
return nil
|
||||
case <-time.After(s.config.KeepaliveInterval):
|
||||
v.Keepalive()
|
||||
|
@ -1085,13 +1103,13 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
|||
}
|
||||
return conn.WriteJSON(msg)
|
||||
}
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
|
||||
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
||||
if poll {
|
||||
return s.sendOldMessages(topics, since, scheduled, v, sub)
|
||||
}
|
||||
subscriberIDs := make([]int, 0)
|
||||
for _, t := range topics {
|
||||
subscriberIDs = append(subscriberIDs, t.Subscribe(sub))
|
||||
subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v.MaybeUserID(), cancel))
|
||||
}
|
||||
defer func() {
|
||||
for i, subscriberID := range subscriberIDs {
|
||||
|
@ -1193,11 +1211,7 @@ func (s *Server) topicFromPath(path string) (*topic, error) {
|
|||
if len(parts) < 2 {
|
||||
return nil, errHTTPBadRequestTopicInvalid
|
||||
}
|
||||
topics, err := s.topicsFromIDs(parts[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return topics[0], nil
|
||||
return s.topicFromID(parts[1])
|
||||
}
|
||||
|
||||
func (s *Server) topicsFromPath(path string) ([]*topic, string, error) {
|
||||
|
@ -1232,6 +1246,14 @@ func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) {
|
|||
return topics, nil
|
||||
}
|
||||
|
||||
func (s *Server) topicFromID(id string) (*topic, error) {
|
||||
topics, err := s.topicsFromIDs(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return topics[0], nil
|
||||
}
|
||||
|
||||
func (s *Server) execManager() {
|
||||
log.Debug("Manager: Starting")
|
||||
defer log.Debug("Manager: Finished")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue