1
0
Fork 0
mirror of https://github.com/binwiederhier/ntfy.git synced 2025-06-02 11:30:35 +02:00

Merge branch 'main' into user-account

This commit is contained in:
binwiederhier 2023-02-22 19:22:47 -05:00
commit 4ab450309f
81 changed files with 4094 additions and 13687 deletions
server

View file

@ -33,16 +33,6 @@ import (
"heckel.io/ntfy/util"
)
/*
- MEDIUM fail2ban to work with ntfy log not nginx log
- HIGH Docs
- tiers
- api
- tokens
*/
// Server is the main server, providing the UI and API for ntfy
type Server struct {
config *Config
@ -56,11 +46,11 @@ type Server struct {
visitors map[string]*visitor // ip:<ip> or user:<user>
firebaseClient *firebaseClient
messages int64
userManager *user.Manager // Might be nil!
messageCache *messageCache // Database that stores the messages
fileCache *fileCache // File system based cache that stores attachments
stripe stripeAPI // Stripe API, can be replaced with a mock
priceCache *util.LookupCache[map[string]string] // Stripe price ID -> formatted price
userManager *user.Manager // Might be nil!
messageCache *messageCache // Database that stores the messages
fileCache *fileCache // File system based cache that stores attachments
stripe stripeAPI // Stripe API, can be replaced with a mock
priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!)
closeChan chan bool
mu sync.Mutex
}
@ -134,24 +124,6 @@ const (
wsPongWait = 15 * time.Second
)
// Log tags
const (
tagStartup = "startup"
tagPublish = "publish"
tagSubscribe = "subscribe"
tagFirebase = "firebase"
tagSMTP = "smtp" // Receive email
tagEmail = "email" // Send email
tagFileCache = "file_cache"
tagMessageCache = "message_cache"
tagStripe = "stripe"
tagAccount = "account"
tagManager = "manager"
tagResetter = "resetter"
tagWebsocket = "websocket"
tagMatrix = "matrix"
)
// New instantiates a new Server. It creates the cache and adds a Firebase
// subscriber (if configured).
func New(conf *Config) (*Server, error) {
@ -327,11 +299,11 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
s.handleError(w, r, v, err)
return
}
if log.IsTrace() {
logvr(v, r).Field("http_request", renderHTTPRequest(r)).Trace("HTTP request started")
} else if log.IsDebug() {
logvr(v, r).Debug("HTTP request started")
ev := logvr(v, r)
if ev.IsTrace() {
ev.Field("http_request", renderHTTPRequest(r)).Trace("HTTP request started")
} else if logvr(v, r).IsDebug() {
ev.Debug("HTTP request started")
}
logvr(v, r).
Timing(func() {
@ -344,8 +316,12 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
}
func (s *Server) handleError(w http.ResponseWriter, r *http.Request, v *visitor, err error) {
httpErr, ok := err.(*errHTTP)
if !ok {
httpErr = errHTTPInternalError
}
isNormalError := strings.Contains(err.Error(), "i/o timeout") || util.Contains([]int{http.StatusNotFound, http.StatusBadRequest, http.StatusTooManyRequests, http.StatusUnauthorized}, httpErr.HTTPCode)
if websocket.IsWebSocketUpgrade(r) {
isNormalError := strings.Contains(err.Error(), "i/o timeout")
if isNormalError {
logvr(v, r).Tag(tagWebsocket).Err(err).Fields(websocketErrorContext(err)).Debug("WebSocket error (this error is okay, it happens a lot): %s", err.Error())
} else {
@ -354,22 +330,15 @@ func (s *Server) handleError(w http.ResponseWriter, r *http.Request, v *visitor,
return // Do not attempt to write to upgraded connection
}
if matrixErr, ok := err.(*errMatrix); ok {
writeMatrixError(w, r, v, matrixErr)
if err := writeMatrixError(w, r, v, matrixErr); err != nil {
logvr(v, r).Tag(tagMatrix).Err(err).Debug("Writing Matrix error failed")
}
return
}
httpErr, ok := err.(*errHTTP)
if !ok {
httpErr = errHTTPInternalError
}
isNormalError := httpErr.HTTPCode == http.StatusNotFound || httpErr.HTTPCode == http.StatusBadRequest || httpErr.HTTPCode == http.StatusTooManyRequests
if isNormalError {
logvr(v, r).
Err(httpErr).
Debug("Connection closed with HTTP %d (ntfy error %d)", httpErr.HTTPCode, httpErr.Code)
logvr(v, r).Err(err).Debug("Connection closed with HTTP %d (ntfy error %d)", httpErr.HTTPCode, httpErr.Code)
} else {
logvr(v, r).
Err(httpErr).
Info("Connection closed with HTTP %d (ntfy error %d)", httpErr.HTTPCode, httpErr.Code)
logvr(v, r).Err(err).Info("Connection closed with HTTP %d (ntfy error %d)", httpErr.HTTPCode, httpErr.Code)
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
@ -629,7 +598,9 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
}
m.Sender = v.IP()
m.User = v.MaybeUserID()
m.Expires = time.Unix(m.Time, 0).Add(vRate.Limits().MessageExpiryDuration).Unix()
if cache {
m.Expires = time.Unix(m.Time, 0).Add(v.Limits().MessageExpiryDuration).Unix()
}
if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil {
return nil, err
}
@ -637,21 +608,18 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
m.Message = emptyMessageBody
}
delayed := m.Time > time.Now().Unix()
logvrm(vRate, r, m).
ev := logvrm(vRate, r, m).
Tag(tagPublish).
Fields(log.Context{
"message_delayed": delayed,
"message_firebase": firebase,
"message_unifiedpush": unifiedpush,
"message_email": email,
"message_subscriber_rate_limited": vRate != v,
}).
Debug("Received message")
if log.IsTrace() {
logvrm(vRate, r, m).
Tag(tagPublish).
Field("message_body", util.MaybeMarshalJSON(m)).
Trace("Message body")
"message_delayed": delayed,
"message_firebase": firebase,
"message_unifiedpush": unifiedpush,
"message_email": email,
})
if ev.IsTrace() {
ev.Field("message_body", util.MaybeMarshalJSON(m)).Trace("Received message")
} else if ev.IsDebug() {
ev.Debug("Received message")
}
if !delayed {
if err := t.Publish(v, m); err != nil {
@ -1176,7 +1144,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
return err
}
err = g.Wait()
if err != nil && websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
if err != nil && websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNoStatusReceived) {
logvr(v, r).Tag(tagWebsocket).Err(err).Fields(websocketErrorContext(err)).Trace("WebSocket connection closed")
return nil // Normal closures are not errors; note: "1006 (abnormal closure)" is treated as normal, because people disconnect a lot
}
@ -1309,161 +1277,6 @@ func (s *Server) topicFromID(id string) (*topic, error) {
return topics[0], nil
}
func (s *Server) execManager() {
// WARNING: Make sure to only selectively lock with the mutex, and be aware that this
// there is no mutex for the entire function.
// Expire visitors from rate visitors map
staleVisitors := 0
log.
Tag(tagManager).
Timing(func() {
s.mu.Lock()
defer s.mu.Unlock()
for ip, v := range s.visitors {
if v.Stale() {
log.Tag(tagManager).With(v).Trace("Deleting stale visitor")
delete(s.visitors, ip)
staleVisitors++
}
}
}).
Field("stale_visitors", staleVisitors).
Debug("Deleted %d stale visitor(s)", staleVisitors)
// Delete expired user tokens and users
if s.userManager != nil {
log.
Tag(tagManager).
Timing(func() {
if err := s.userManager.RemoveExpiredTokens(); err != nil {
log.Tag(tagManager).Err(err).Warn("Error expiring user tokens")
}
if err := s.userManager.RemoveDeletedUsers(); err != nil {
log.Tag(tagManager).Err(err).Warn("Error deleting soft-deleted users")
}
}).
Debug("Removed expired tokens and users")
}
// Delete expired attachments
if s.fileCache != nil {
log.
Tag(tagManager).
Timing(func() {
ids, err := s.messageCache.AttachmentsExpired()
if err != nil {
log.Tag(tagManager).Err(err).Warn("Error retrieving expired attachments")
} else if len(ids) > 0 {
if log.Tag(tagManager).IsDebug() {
log.Tag(tagManager).Debug("Deleting attachments %s", strings.Join(ids, ", "))
}
if err := s.fileCache.Remove(ids...); err != nil {
log.Tag(tagManager).Err(err).Warn("Error deleting attachments")
}
if err := s.messageCache.MarkAttachmentsDeleted(ids...); err != nil {
log.Tag(tagManager).Err(err).Warn("Error marking attachments deleted")
}
} else {
log.Tag(tagManager).Debug("No expired attachments to delete")
}
}).
Debug("Deleted expired attachments")
}
// Prune messages
log.
Tag(tagManager).
Timing(func() {
expiredMessageIDs, err := s.messageCache.MessagesExpired()
if err != nil {
log.Tag(tagManager).Err(err).Warn("Error retrieving expired messages")
} else if len(expiredMessageIDs) > 0 {
if err := s.fileCache.Remove(expiredMessageIDs...); err != nil {
log.Tag(tagManager).Err(err).Warn("Error deleting attachments for expired messages")
}
if err := s.messageCache.DeleteMessages(expiredMessageIDs...); err != nil {
log.Tag(tagManager).Err(err).Warn("Error marking attachments deleted")
}
} else {
log.Tag(tagManager).Debug("No expired messages to delete")
}
}).
Debug("Pruned messages")
// Message count per topic
var messagesCached int
messageCounts, err := s.messageCache.MessageCounts()
if err != nil {
log.Tag(tagManager).Err(err).Warn("Cannot get message counts")
messageCounts = make(map[string]int) // Empty, so we can continue
}
for _, count := range messageCounts {
messagesCached += count
}
// Remove subscriptions without subscribers
var emptyTopics, subscribers int
log.
Tag(tagManager).
Timing(func() {
s.mu.Lock()
defer s.mu.Unlock()
for _, t := range s.topics {
subs := t.SubscribersCount()
ev := log.Tag(tagManager)
if ev.IsTrace() {
expiryMessage := ""
if subs == 0 {
expiryTime := time.Until(t.vRateExpires)
expiryMessage = ", expires in " + expiryTime.String()
}
ev.Trace("- topic %s: %d subscribers%s", t.ID, subs, expiryMessage)
}
msgs, exists := messageCounts[t.ID]
if t.Stale() && (!exists || msgs == 0) {
log.Tag(tagManager).Trace("Deleting empty topic %s", t.ID)
emptyTopics++
delete(s.topics, t.ID)
continue
}
subscribers += subs
}
}).
Debug("Removed %d empty topic(s)", emptyTopics)
// Mail stats
var receivedMailTotal, receivedMailSuccess, receivedMailFailure int64
if s.smtpServerBackend != nil {
receivedMailTotal, receivedMailSuccess, receivedMailFailure = s.smtpServerBackend.Counts()
}
var sentMailTotal, sentMailSuccess, sentMailFailure int64
if s.smtpSender != nil {
sentMailTotal, sentMailSuccess, sentMailFailure = s.smtpSender.Counts()
}
// Print stats
s.mu.Lock()
messagesCount, topicsCount, visitorsCount := s.messages, len(s.topics), len(s.visitors)
s.mu.Unlock()
log.
Tag(tagManager).
Fields(log.Context{
"messages_published": messagesCount,
"messages_cached": messagesCached,
"topics_active": topicsCount,
"subscribers": subscribers,
"visitors": visitorsCount,
"emails_received": receivedMailTotal,
"emails_received_success": receivedMailSuccess,
"emails_received_failure": receivedMailFailure,
"emails_sent": sentMailTotal,
"emails_sent_success": sentMailSuccess,
"emails_sent_failure": sentMailFailure,
}).
Info("Server stats")
}
func (s *Server) runSMTPServer() error {
s.smtpServerBackend = newMailBackend(s.config, s.handle)
s.smtpServer = smtp.NewServer(s.smtpServerBackend)