1
0
Fork 0
mirror of https://github.com/binwiederhier/ntfy.git synced 2025-06-14 08:33:20 +02:00

Merge branch 'main' into logging

This commit is contained in:
Philipp Heckel 2022-05-31 23:39:11 -04:00
commit a04cf5fcb6
20 changed files with 917 additions and 715 deletions
server

View file

@ -7,13 +7,11 @@ import (
"embed"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"heckel.io/ntfy/log"
"io"
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path"
@ -34,22 +32,22 @@ import (
// Server is the main server, providing the UI and API for ntfy
type Server struct {
config *Config
httpServer *http.Server
httpsServer *http.Server
unixListener net.Listener
smtpServer *smtp.Server
smtpBackend *smtpBackend
topics map[string]*topic
visitors map[string]*visitor
firebase subscriber
mailer mailer
messages int64
auth auth.Auther
messageCache *messageCache
fileCache *fileCache
closeChan chan bool
mu sync.Mutex
config *Config
httpServer *http.Server
httpsServer *http.Server
unixListener net.Listener
smtpServer *smtp.Server
smtpBackend *smtpBackend
topics map[string]*topic
visitors map[string]*visitor
firebaseClient *firebaseClient
mailer mailer
messages int64
auth auth.Auther
messageCache *messageCache
fileCache *fileCache
closeChan chan bool
mu sync.Mutex
}
// handleFunc extends the normal http.HandlerFunc to be able to easily return errors
@ -136,23 +134,23 @@ func New(conf *Config) (*Server, error) {
return nil, err
}
}
var firebaseSubscriber subscriber
var firebaseClient *firebaseClient
if conf.FirebaseKeyFile != "" {
var err error
firebaseSubscriber, err = createFirebaseSubscriber(conf.FirebaseKeyFile, auther)
sender, err := newFirebaseSender(conf.FirebaseKeyFile)
if err != nil {
return nil, err
}
firebaseClient = newFirebaseClient(sender, auther)
}
return &Server{
config: conf,
messageCache: messageCache,
fileCache: fileCache,
firebase: firebaseSubscriber,
mailer: mailer,
topics: topics,
auth: auther,
visitors: make(map[string]*visitor),
config: conf,
messageCache: messageCache,
fileCache: fileCache,
firebaseClient: firebaseClient,
mailer: mailer,
topics: topics,
auth: auther,
visitors: make(map[string]*visitor),
}, nil
}
@ -221,7 +219,7 @@ func (s *Server) Run() error {
}
s.mu.Unlock()
go s.runManager()
go s.runDelaySender()
go s.runDelayedSender()
go s.runFirebaseKeepaliver()
return <-errChan
@ -439,17 +437,17 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
log.Debug("[%s] %s %s: ev=%s, body=%d bytes, delayed=%t, fb=%t, cache=%t, up=%t, email=%s",
v.ip, r.Method, r.URL.Path, m.Event, len(body.PeekedBytes), delayed, firebase, cache, unifiedpush, email)
if !delayed {
if err := t.Publish(m); err != nil {
if err := t.Publish(v, m); err != nil {
return err
}
}
if s.firebase != nil && firebase && !delayed {
if s.firebaseClient != nil && firebase && !delayed {
go s.sendToFirebase(v, m)
}
if s.mailer != nil && email != "" && !delayed {
go s.sendEmail(v, m, email)
}
if s.config.UpstreamBaseURL != "" {
if s.config.UpstreamBaseURL != "" && !delayed {
go s.forwardPollRequest(v, m)
}
if cache {
@ -469,7 +467,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
}
func (s *Server) sendToFirebase(v *visitor, m *message) {
if err := s.firebase(m); err != nil {
if err := s.firebaseClient.Send(v, m); err != nil {
log.Warn("[%s] FB - Unable to publish to Firebase: %v", v.ip, err.Error())
}
}
@ -490,7 +488,10 @@ func (s *Server) forwardPollRequest(v *visitor, m *message) {
return
}
req.Header.Set("X-Poll-ID", m.ID)
response, err := http.DefaultClient.Do(req)
var httpClient = &http.Client{
Timeout: time.Second * 10,
}
response, err := httpClient.Do(req)
if err != nil {
log.Warn("[%s] FWD - Unable to forward poll request: %v", v.ip, err.Error())
return
@ -572,6 +573,7 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca
return false, false, "", false, errHTTPBadRequestDelayTooLarge
}
m.Time = delay.Unix()
m.Sender = v.ip // Important for rate limiting
}
actionsStr := readParam(r, "x-actions", "actions", "action")
if actionsStr != "" {
@ -667,7 +669,7 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message,
m.Attachment = &attachment{}
}
var ext string
m.Attachment.Owner = v.ip // Important for attachment rate limiting
m.Sender = v.ip // Important for attachment rate limiting
m.Attachment.Expires = time.Now().Add(s.config.AttachmentExpiryDuration).Unix()
m.Attachment.Type, ext = util.DetectContentType(body.PeekedBytes, m.Attachment.Name)
m.Attachment.URL = fmt.Sprintf("%s/file/%s%s", s.config.BaseURL, m.ID, ext)
@ -735,7 +737,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
return err
}
var wlock sync.Mutex
sub := func(msg *message) error {
sub := func(v *visitor, msg *message) error {
if !filters.Pass(msg) {
return nil
}
@ -756,7 +758,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
if poll {
return s.sendOldMessages(topics, since, scheduled, sub)
return s.sendOldMessages(topics, since, scheduled, v, sub)
}
subscriberIDs := make([]int, 0)
for _, t := range topics {
@ -767,10 +769,10 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
topics[i].Unsubscribe(subscriberID) // Order!
}
}()
if err := sub(newOpenMessage(topicsStr)); err != nil { // Send out open message
if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
return err
}
if err := s.sendOldMessages(topics, since, scheduled, sub); err != nil {
if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
return err
}
for {
@ -779,7 +781,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
return nil
case <-time.After(s.config.KeepaliveInterval):
v.Keepalive()
if err := sub(newKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message
if err := sub(v, newKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message
return err
}
}
@ -853,7 +855,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
}
}
})
sub := func(msg *message) error {
sub := func(v *visitor, msg *message) error {
if !filters.Pass(msg) {
return nil
}
@ -866,7 +868,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
}
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
if poll {
return s.sendOldMessages(topics, since, scheduled, sub)
return s.sendOldMessages(topics, since, scheduled, v, sub)
}
subscriberIDs := make([]int, 0)
for _, t := range topics {
@ -877,10 +879,10 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
topics[i].Unsubscribe(subscriberID) // Order!
}
}()
if err := sub(newOpenMessage(topicsStr)); err != nil { // Send out open message
if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
return err
}
if err := s.sendOldMessages(topics, since, scheduled, sub); err != nil {
if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
return err
}
err = g.Wait()
@ -904,7 +906,7 @@ func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, schedu
return
}
func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, sub subscriber) error {
func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, v *visitor, sub subscriber) error {
if since.IsNone() {
return nil
}
@ -914,7 +916,7 @@ func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled b
return err
}
for _, m := range messages {
if err := sub(m); err != nil {
if err := sub(v, m); err != nil {
return err
}
}
@ -1061,23 +1063,7 @@ func (s *Server) updateStatsAndPrune() {
}
func (s *Server) runSMTPServer() error {
sub := func(m *message) error {
url := fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic)
req, err := http.NewRequest("PUT", url, strings.NewReader(m.Message))
if err != nil {
return err
}
if m.Title != "" {
req.Header.Set("Title", m.Title)
}
rr := httptest.NewRecorder()
s.handle(rr, req)
if rr.Code != http.StatusOK {
return errors.New("error: " + rr.Body.String())
}
return nil
}
s.smtpBackend = newMailBackend(s.config, sub)
s.smtpBackend = newMailBackend(s.config, s.handle)
s.smtpServer = smtp.NewServer(s.smtpBackend)
s.smtpServer.Addr = s.config.SMTPServerListen
s.smtpServer.Domain = s.config.SMTPServerDomain
@ -1100,10 +1086,10 @@ func (s *Server) runManager() {
}
}
func (s *Server) runDelaySender() {
func (s *Server) runDelayedSender() {
for {
select {
case <-time.After(s.config.AtSenderInterval):
case <-time.After(s.config.DelayedSenderInterval):
if err := s.sendDelayedMessages(); err != nil {
log.Warn("error sending scheduled messages: %s", err.Error())
}
@ -1114,19 +1100,16 @@ func (s *Server) runDelaySender() {
}
func (s *Server) runFirebaseKeepaliver() {
if s.firebase == nil {
if s.firebaseClient == nil {
return
}
v := newVisitor(s.config, s.messageCache, "0.0.0.0") // Background process, not a real visitor
for {
select {
case <-time.After(s.config.FirebaseKeepaliveInterval):
if err := s.firebase(newKeepaliveMessage(firebaseControlTopic)); err != nil {
log.Info("error sending Firebase keepalive message to %s: %s", firebaseControlTopic, err.Error())
}
s.sendToFirebase(v, newKeepaliveMessage(firebaseControlTopic))
case <-time.After(s.config.FirebasePollInterval):
if err := s.firebase(newKeepaliveMessage(firebasePollTopic)); err != nil {
log.Info("error sending Firebase keepalive message to %s: %s", firebasePollTopic, err.Error())
}
s.sendToFirebase(v, newKeepaliveMessage(firebasePollTopic))
case <-s.closeChan:
return
}
@ -1134,27 +1117,39 @@ func (s *Server) runFirebaseKeepaliver() {
}
func (s *Server) sendDelayedMessages() error {
s.mu.Lock()
defer s.mu.Unlock()
messages, err := s.messageCache.MessagesDue()
if err != nil {
return err
}
for _, m := range messages {
t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published
if ok {
if err := t.Publish(m); err != nil {
log.Info("unable to publish message %s to topic %s: %v", m.ID, m.Topic, err.Error())
v := s.visitorFromIP(m.Sender)
if err := s.sendDelayedMessage(v, m); err != nil {
log.Warn("error sending delayed message: %s", err.Error())
}
}
return nil
}
func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
s.mu.Lock()
defer s.mu.Unlock()
t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published
if ok {
go func() {
// We do not rate-limit messages here, since we've rate limited them in the PUT/POST handler
if err := t.Publish(v, m); err != nil {
log.Warn("unable to publish message %s to topic %s: %v", m.ID, m.Topic, err.Error())
}
}
if s.firebase != nil { // Firebase subscribers may not show up in topics map
if err := s.firebase(m); err != nil {
log.Info("unable to publish to Firebase: %v", err.Error())
}
}
if err := s.messageCache.MarkPublished(m); err != nil {
return err
}
}()
}
if s.firebaseClient != nil { // Firebase subscribers may not show up in topics map
go s.sendToFirebase(v, m)
}
if s.config.UpstreamBaseURL != "" {
go s.forwardPollRequest(v, m)
}
if err := s.messageCache.MarkPublished(m); err != nil {
return err
}
return nil
}
@ -1294,8 +1289,6 @@ func extractUserPass(r *http.Request) (username string, password string, ok bool
// visitor creates or retrieves a rate.Limiter for the given visitor.
// This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT).
func (s *Server) visitor(r *http.Request) *visitor {
s.mu.Lock()
defer s.mu.Unlock()
remoteAddr := r.RemoteAddr
ip, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
@ -1304,6 +1297,12 @@ func (s *Server) visitor(r *http.Request) *visitor {
if s.config.BehindProxy && r.Header.Get("X-Forwarded-For") != "" {
ip = r.Header.Get("X-Forwarded-For")
}
return s.visitorFromIP(ip)
}
func (s *Server) visitorFromIP(ip string) *visitor {
s.mu.Lock()
defer s.mu.Unlock()
v, exists := s.visitors[ip]
if !exists {
s.visitors[ip] = newVisitor(s.config, s.messageCache, ip)