mirror of
https://github.com/binwiederhier/ntfy.git
synced 2025-05-30 02:15:40 +02:00
Firebase quota limit
This commit is contained in:
parent
8a81c8e95b
commit
8283b6be97
9 changed files with 180 additions and 119 deletions
server
|
@ -7,13 +7,11 @@ import (
|
|||
"embed"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
|
@ -221,7 +219,7 @@ func (s *Server) Run() error {
|
|||
}
|
||||
s.mu.Unlock()
|
||||
go s.runManager()
|
||||
go s.runAtSender()
|
||||
go s.runDelayedSender()
|
||||
go s.runFirebaseKeepaliver()
|
||||
|
||||
return <-errChan
|
||||
|
@ -435,7 +433,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
|
|||
}
|
||||
delayed := m.Time > time.Now().Unix()
|
||||
if !delayed {
|
||||
if err := t.Publish(m); err != nil {
|
||||
if err := t.Publish(v, m); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -465,7 +463,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
|
|||
}
|
||||
|
||||
func (s *Server) sendToFirebase(v *visitor, m *message) {
|
||||
if err := s.firebase(m); err != nil {
|
||||
if err := s.firebase(v, m); err != nil {
|
||||
log.Printf("[%s] FB - Unable to publish to Firebase: %v", v.ip, err.Error())
|
||||
}
|
||||
}
|
||||
|
@ -731,7 +729,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
|||
return err
|
||||
}
|
||||
var wlock sync.Mutex
|
||||
sub := func(msg *message) error {
|
||||
sub := func(v *visitor, msg *message) error {
|
||||
if !filters.Pass(msg) {
|
||||
return nil
|
||||
}
|
||||
|
@ -752,7 +750,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
|||
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
|
||||
w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
|
||||
if poll {
|
||||
return s.sendOldMessages(topics, since, scheduled, sub)
|
||||
return s.sendOldMessages(topics, since, scheduled, v, sub)
|
||||
}
|
||||
subscriberIDs := make([]int, 0)
|
||||
for _, t := range topics {
|
||||
|
@ -763,10 +761,10 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
|||
topics[i].Unsubscribe(subscriberID) // Order!
|
||||
}
|
||||
}()
|
||||
if err := sub(newOpenMessage(topicsStr)); err != nil { // Send out open message
|
||||
if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
|
||||
return err
|
||||
}
|
||||
if err := s.sendOldMessages(topics, since, scheduled, sub); err != nil {
|
||||
if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
|
||||
return err
|
||||
}
|
||||
for {
|
||||
|
@ -775,7 +773,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
|||
return nil
|
||||
case <-time.After(s.config.KeepaliveInterval):
|
||||
v.Keepalive()
|
||||
if err := sub(newKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message
|
||||
if err := sub(v, newKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -849,7 +847,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
|||
}
|
||||
}
|
||||
})
|
||||
sub := func(msg *message) error {
|
||||
sub := func(v *visitor, msg *message) error {
|
||||
if !filters.Pass(msg) {
|
||||
return nil
|
||||
}
|
||||
|
@ -862,7 +860,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
|||
}
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
|
||||
if poll {
|
||||
return s.sendOldMessages(topics, since, scheduled, sub)
|
||||
return s.sendOldMessages(topics, since, scheduled, v, sub)
|
||||
}
|
||||
subscriberIDs := make([]int, 0)
|
||||
for _, t := range topics {
|
||||
|
@ -873,10 +871,10 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
|||
topics[i].Unsubscribe(subscriberID) // Order!
|
||||
}
|
||||
}()
|
||||
if err := sub(newOpenMessage(topicsStr)); err != nil { // Send out open message
|
||||
if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
|
||||
return err
|
||||
}
|
||||
if err := s.sendOldMessages(topics, since, scheduled, sub); err != nil {
|
||||
if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
|
||||
return err
|
||||
}
|
||||
err = g.Wait()
|
||||
|
@ -900,7 +898,7 @@ func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, schedu
|
|||
return
|
||||
}
|
||||
|
||||
func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, sub subscriber) error {
|
||||
func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, v *visitor, sub subscriber) error {
|
||||
if since.IsNone() {
|
||||
return nil
|
||||
}
|
||||
|
@ -910,7 +908,7 @@ func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled b
|
|||
return err
|
||||
}
|
||||
for _, m := range messages {
|
||||
if err := sub(m); err != nil {
|
||||
if err := sub(v, m); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -1057,23 +1055,7 @@ func (s *Server) updateStatsAndPrune() {
|
|||
}
|
||||
|
||||
func (s *Server) runSMTPServer() error {
|
||||
sub := func(m *message) error {
|
||||
url := fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic)
|
||||
req, err := http.NewRequest("PUT", url, strings.NewReader(m.Message))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if m.Title != "" {
|
||||
req.Header.Set("Title", m.Title)
|
||||
}
|
||||
rr := httptest.NewRecorder()
|
||||
s.handle(rr, req)
|
||||
if rr.Code != http.StatusOK {
|
||||
return errors.New("error: " + rr.Body.String())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
s.smtpBackend = newMailBackend(s.config, sub)
|
||||
s.smtpBackend = newMailBackend(s.config, s.handle)
|
||||
s.smtpServer = smtp.NewServer(s.smtpBackend)
|
||||
s.smtpServer.Addr = s.config.SMTPServerListen
|
||||
s.smtpServer.Domain = s.config.SMTPServerDomain
|
||||
|
@ -1096,7 +1078,7 @@ func (s *Server) runManager() {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Server) runAtSender() {
|
||||
func (s *Server) runDelayedSender() {
|
||||
for {
|
||||
select {
|
||||
case <-time.After(s.config.AtSenderInterval):
|
||||
|
@ -1113,14 +1095,15 @@ func (s *Server) runFirebaseKeepaliver() {
|
|||
if s.firebase == nil {
|
||||
return
|
||||
}
|
||||
v := newVisitor(s.config, s.messageCache, "0.0.0.0")
|
||||
for {
|
||||
select {
|
||||
case <-time.After(s.config.FirebaseKeepaliveInterval):
|
||||
if err := s.firebase(newKeepaliveMessage(firebaseControlTopic)); err != nil {
|
||||
if err := s.firebase(v, newKeepaliveMessage(firebaseControlTopic)); err != nil {
|
||||
log.Printf("error sending Firebase keepalive message to %s: %s", firebaseControlTopic, err.Error())
|
||||
}
|
||||
case <-time.After(s.config.FirebasePollInterval):
|
||||
if err := s.firebase(newKeepaliveMessage(firebasePollTopic)); err != nil {
|
||||
if err := s.firebase(v, newKeepaliveMessage(firebasePollTopic)); err != nil {
|
||||
log.Printf("error sending Firebase keepalive message to %s: %s", firebasePollTopic, err.Error())
|
||||
}
|
||||
case <-s.closeChan:
|
||||
|
@ -1130,28 +1113,36 @@ func (s *Server) runFirebaseKeepaliver() {
|
|||
}
|
||||
|
||||
func (s *Server) sendDelayedMessages() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
messages, err := s.messageCache.MessagesDue()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, m := range messages {
|
||||
t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published
|
||||
if ok {
|
||||
if err := t.Publish(m); err != nil {
|
||||
log.Printf("unable to publish message %s to topic %s: %v", m.ID, m.Topic, err.Error())
|
||||
}
|
||||
v := s.visitorFromIP("0.0.0.0") // FIXME: get message owner!!
|
||||
if err := s.sendDelayedMessage(v, m); err != nil {
|
||||
log.Printf("error sending delayed message: %s", err.Error())
|
||||
}
|
||||
if s.firebase != nil { // Firebase subscribers may not show up in topics map
|
||||
if err := s.firebase(m); err != nil {
|
||||
log.Printf("unable to publish to Firebase: %v", err.Error())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published
|
||||
if ok {
|
||||
if err := t.Publish(v, m); err != nil {
|
||||
return fmt.Errorf("unable to publish message %s to topic %s: %v", m.ID, m.Topic, err.Error())
|
||||
}
|
||||
if err := s.messageCache.MarkPublished(m); err != nil {
|
||||
return err
|
||||
}
|
||||
if s.firebase != nil { // Firebase subscribers may not show up in topics map
|
||||
if err := s.firebase(v, m); err != nil {
|
||||
return fmt.Errorf("unable to publish to Firebase: %v", err.Error())
|
||||
}
|
||||
}
|
||||
if err := s.messageCache.MarkPublished(m); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -1290,8 +1281,6 @@ func extractUserPass(r *http.Request) (username string, password string, ok bool
|
|||
// visitor creates or retrieves a rate.Limiter for the given visitor.
|
||||
// This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT).
|
||||
func (s *Server) visitor(r *http.Request) *visitor {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
remoteAddr := r.RemoteAddr
|
||||
ip, _, err := net.SplitHostPort(remoteAddr)
|
||||
if err != nil {
|
||||
|
@ -1300,6 +1289,12 @@ func (s *Server) visitor(r *http.Request) *visitor {
|
|||
if s.config.BehindProxy && r.Header.Get("X-Forwarded-For") != "" {
|
||||
ip = r.Header.Get("X-Forwarded-For")
|
||||
}
|
||||
return s.visitorFromIP(ip)
|
||||
}
|
||||
|
||||
func (s *Server) visitorFromIP(ip string) *visitor {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
v, exists := s.visitors[ip]
|
||||
if !exists {
|
||||
s.visitors[ip] = newVisitor(s.config, s.messageCache, ip)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue