1
0
Fork 0
mirror of https://github.com/binwiederhier/ntfy.git synced 2024-11-22 19:33:27 +01:00

Firebase quota limit

This commit is contained in:
Philipp Heckel 2022-05-31 20:38:56 -04:00
parent 8a81c8e95b
commit 8283b6be97
9 changed files with 180 additions and 119 deletions

View file

@ -6,15 +6,16 @@ import (
// Defines default config settings (excluding limits, see below) // Defines default config settings (excluding limits, see below)
const ( const (
DefaultListenHTTP = ":80" DefaultListenHTTP = ":80"
DefaultCacheDuration = 12 * time.Hour DefaultCacheDuration = 12 * time.Hour
DefaultKeepaliveInterval = 45 * time.Second // Not too frequently to save battery (Android read timeout used to be 77s!) DefaultKeepaliveInterval = 45 * time.Second // Not too frequently to save battery (Android read timeout used to be 77s!)
DefaultManagerInterval = time.Minute DefaultManagerInterval = time.Minute
DefaultAtSenderInterval = 10 * time.Second DefaultAtSenderInterval = 10 * time.Second
DefaultMinDelay = 10 * time.Second DefaultMinDelay = 10 * time.Second
DefaultMaxDelay = 3 * 24 * time.Hour DefaultMaxDelay = 3 * 24 * time.Hour
DefaultFirebaseKeepaliveInterval = 3 * time.Hour // ~control topic (Android), not too frequently to save battery DefaultFirebaseKeepaliveInterval = 3 * time.Hour // ~control topic (Android), not too frequently to save battery
DefaultFirebasePollInterval = 20 * time.Minute // ~poll topic (iOS), max. 2-3 times per hour (see docs) DefaultFirebasePollInterval = 20 * time.Minute // ~poll topic (iOS), max. 2-3 times per hour (see docs)
DefaultFirebaseQuotaLimitPenaltyDuration = 10 * time.Minute
) )
// Defines all global and per-visitor limits // Defines all global and per-visitor limits
@ -69,6 +70,7 @@ type Config struct {
AtSenderInterval time.Duration AtSenderInterval time.Duration
FirebaseKeepaliveInterval time.Duration FirebaseKeepaliveInterval time.Duration
FirebasePollInterval time.Duration FirebasePollInterval time.Duration
FirebaseQuotaLimitPenaltyDuration time.Duration
UpstreamBaseURL string UpstreamBaseURL string
SMTPSenderAddr string SMTPSenderAddr string
SMTPSenderUser string SMTPSenderUser string
@ -121,6 +123,7 @@ func NewConfig() *Config {
AtSenderInterval: DefaultAtSenderInterval, AtSenderInterval: DefaultAtSenderInterval,
FirebaseKeepaliveInterval: DefaultFirebaseKeepaliveInterval, FirebaseKeepaliveInterval: DefaultFirebaseKeepaliveInterval,
FirebasePollInterval: DefaultFirebasePollInterval, FirebasePollInterval: DefaultFirebasePollInterval,
FirebaseQuotaLimitPenaltyDuration: DefaultFirebaseQuotaLimitPenaltyDuration,
TotalTopicLimit: DefaultTotalTopicLimit, TotalTopicLimit: DefaultTotalTopicLimit,
VisitorSubscriptionLimit: DefaultVisitorSubscriptionLimit, VisitorSubscriptionLimit: DefaultVisitorSubscriptionLimit,
VisitorAttachmentTotalSizeLimit: DefaultVisitorAttachmentTotalSizeLimit, VisitorAttachmentTotalSizeLimit: DefaultVisitorAttachmentTotalSizeLimit,

View file

@ -59,6 +59,7 @@ var (
errHTTPTooManyRequestsLimitSubscriptions = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"} errHTTPTooManyRequestsLimitSubscriptions = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPTooManyRequestsLimitTotalTopics = &errHTTP{42904, http.StatusTooManyRequests, "limit reached: the total number of topics on the server has been reached, please contact the admin", "https://ntfy.sh/docs/publish/#limitations"} errHTTPTooManyRequestsLimitTotalTopics = &errHTTP{42904, http.StatusTooManyRequests, "limit reached: the total number of topics on the server has been reached, please contact the admin", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPTooManyRequestsAttachmentBandwidthLimit = &errHTTP{42905, http.StatusTooManyRequests, "too many requests: daily bandwidth limit reached", "https://ntfy.sh/docs/publish/#limitations"} errHTTPTooManyRequestsAttachmentBandwidthLimit = &errHTTP{42905, http.StatusTooManyRequests, "too many requests: daily bandwidth limit reached", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPTooManyRequestsFirebaseQuotaReached = &errHTTP{42906, http.StatusTooManyRequests, "too many requests: Firebase quota for topic reached", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""} errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""}
errHTTPInternalErrorInvalidFilePath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid file path", ""} errHTTPInternalErrorInvalidFilePath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid file path", ""}
) )

View file

@ -7,13 +7,11 @@ import (
"embed" "embed"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"log" "log"
"net" "net"
"net/http" "net/http"
"net/http/httptest"
"net/url" "net/url"
"os" "os"
"path" "path"
@ -221,7 +219,7 @@ func (s *Server) Run() error {
} }
s.mu.Unlock() s.mu.Unlock()
go s.runManager() go s.runManager()
go s.runAtSender() go s.runDelayedSender()
go s.runFirebaseKeepaliver() go s.runFirebaseKeepaliver()
return <-errChan return <-errChan
@ -435,7 +433,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
} }
delayed := m.Time > time.Now().Unix() delayed := m.Time > time.Now().Unix()
if !delayed { if !delayed {
if err := t.Publish(m); err != nil { if err := t.Publish(v, m); err != nil {
return err return err
} }
} }
@ -465,7 +463,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
} }
func (s *Server) sendToFirebase(v *visitor, m *message) { func (s *Server) sendToFirebase(v *visitor, m *message) {
if err := s.firebase(m); err != nil { if err := s.firebase(v, m); err != nil {
log.Printf("[%s] FB - Unable to publish to Firebase: %v", v.ip, err.Error()) log.Printf("[%s] FB - Unable to publish to Firebase: %v", v.ip, err.Error())
} }
} }
@ -731,7 +729,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
return err return err
} }
var wlock sync.Mutex var wlock sync.Mutex
sub := func(msg *message) error { sub := func(v *visitor, msg *message) error {
if !filters.Pass(msg) { if !filters.Pass(msg) {
return nil return nil
} }
@ -752,7 +750,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset! w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
if poll { if poll {
return s.sendOldMessages(topics, since, scheduled, sub) return s.sendOldMessages(topics, since, scheduled, v, sub)
} }
subscriberIDs := make([]int, 0) subscriberIDs := make([]int, 0)
for _, t := range topics { for _, t := range topics {
@ -763,10 +761,10 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
topics[i].Unsubscribe(subscriberID) // Order! topics[i].Unsubscribe(subscriberID) // Order!
} }
}() }()
if err := sub(newOpenMessage(topicsStr)); err != nil { // Send out open message if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
return err return err
} }
if err := s.sendOldMessages(topics, since, scheduled, sub); err != nil { if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
return err return err
} }
for { for {
@ -775,7 +773,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
return nil return nil
case <-time.After(s.config.KeepaliveInterval): case <-time.After(s.config.KeepaliveInterval):
v.Keepalive() v.Keepalive()
if err := sub(newKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message if err := sub(v, newKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message
return err return err
} }
} }
@ -849,7 +847,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
} }
} }
}) })
sub := func(msg *message) error { sub := func(v *visitor, msg *message) error {
if !filters.Pass(msg) { if !filters.Pass(msg) {
return nil return nil
} }
@ -862,7 +860,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
} }
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
if poll { if poll {
return s.sendOldMessages(topics, since, scheduled, sub) return s.sendOldMessages(topics, since, scheduled, v, sub)
} }
subscriberIDs := make([]int, 0) subscriberIDs := make([]int, 0)
for _, t := range topics { for _, t := range topics {
@ -873,10 +871,10 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
topics[i].Unsubscribe(subscriberID) // Order! topics[i].Unsubscribe(subscriberID) // Order!
} }
}() }()
if err := sub(newOpenMessage(topicsStr)); err != nil { // Send out open message if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
return err return err
} }
if err := s.sendOldMessages(topics, since, scheduled, sub); err != nil { if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
return err return err
} }
err = g.Wait() err = g.Wait()
@ -900,7 +898,7 @@ func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, schedu
return return
} }
func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, sub subscriber) error { func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, v *visitor, sub subscriber) error {
if since.IsNone() { if since.IsNone() {
return nil return nil
} }
@ -910,7 +908,7 @@ func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled b
return err return err
} }
for _, m := range messages { for _, m := range messages {
if err := sub(m); err != nil { if err := sub(v, m); err != nil {
return err return err
} }
} }
@ -1057,23 +1055,7 @@ func (s *Server) updateStatsAndPrune() {
} }
func (s *Server) runSMTPServer() error { func (s *Server) runSMTPServer() error {
sub := func(m *message) error { s.smtpBackend = newMailBackend(s.config, s.handle)
url := fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic)
req, err := http.NewRequest("PUT", url, strings.NewReader(m.Message))
if err != nil {
return err
}
if m.Title != "" {
req.Header.Set("Title", m.Title)
}
rr := httptest.NewRecorder()
s.handle(rr, req)
if rr.Code != http.StatusOK {
return errors.New("error: " + rr.Body.String())
}
return nil
}
s.smtpBackend = newMailBackend(s.config, sub)
s.smtpServer = smtp.NewServer(s.smtpBackend) s.smtpServer = smtp.NewServer(s.smtpBackend)
s.smtpServer.Addr = s.config.SMTPServerListen s.smtpServer.Addr = s.config.SMTPServerListen
s.smtpServer.Domain = s.config.SMTPServerDomain s.smtpServer.Domain = s.config.SMTPServerDomain
@ -1096,7 +1078,7 @@ func (s *Server) runManager() {
} }
} }
func (s *Server) runAtSender() { func (s *Server) runDelayedSender() {
for { for {
select { select {
case <-time.After(s.config.AtSenderInterval): case <-time.After(s.config.AtSenderInterval):
@ -1113,14 +1095,15 @@ func (s *Server) runFirebaseKeepaliver() {
if s.firebase == nil { if s.firebase == nil {
return return
} }
v := newVisitor(s.config, s.messageCache, "0.0.0.0")
for { for {
select { select {
case <-time.After(s.config.FirebaseKeepaliveInterval): case <-time.After(s.config.FirebaseKeepaliveInterval):
if err := s.firebase(newKeepaliveMessage(firebaseControlTopic)); err != nil { if err := s.firebase(v, newKeepaliveMessage(firebaseControlTopic)); err != nil {
log.Printf("error sending Firebase keepalive message to %s: %s", firebaseControlTopic, err.Error()) log.Printf("error sending Firebase keepalive message to %s: %s", firebaseControlTopic, err.Error())
} }
case <-time.After(s.config.FirebasePollInterval): case <-time.After(s.config.FirebasePollInterval):
if err := s.firebase(newKeepaliveMessage(firebasePollTopic)); err != nil { if err := s.firebase(v, newKeepaliveMessage(firebasePollTopic)); err != nil {
log.Printf("error sending Firebase keepalive message to %s: %s", firebasePollTopic, err.Error()) log.Printf("error sending Firebase keepalive message to %s: %s", firebasePollTopic, err.Error())
} }
case <-s.closeChan: case <-s.closeChan:
@ -1130,28 +1113,36 @@ func (s *Server) runFirebaseKeepaliver() {
} }
func (s *Server) sendDelayedMessages() error { func (s *Server) sendDelayedMessages() error {
s.mu.Lock()
defer s.mu.Unlock()
messages, err := s.messageCache.MessagesDue() messages, err := s.messageCache.MessagesDue()
if err != nil { if err != nil {
return err return err
} }
for _, m := range messages { for _, m := range messages {
t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published v := s.visitorFromIP("0.0.0.0") // FIXME: get message owner!!
if ok { if err := s.sendDelayedMessage(v, m); err != nil {
if err := t.Publish(m); err != nil { log.Printf("error sending delayed message: %s", err.Error())
log.Printf("unable to publish message %s to topic %s: %v", m.ID, m.Topic, err.Error())
}
} }
if s.firebase != nil { // Firebase subscribers may not show up in topics map }
if err := s.firebase(m); err != nil { return nil
log.Printf("unable to publish to Firebase: %v", err.Error()) }
}
func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
s.mu.Lock()
defer s.mu.Unlock()
t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published
if ok {
if err := t.Publish(v, m); err != nil {
return fmt.Errorf("unable to publish message %s to topic %s: %v", m.ID, m.Topic, err.Error())
} }
if err := s.messageCache.MarkPublished(m); err != nil { }
return err if s.firebase != nil { // Firebase subscribers may not show up in topics map
if err := s.firebase(v, m); err != nil {
return fmt.Errorf("unable to publish to Firebase: %v", err.Error())
} }
} }
if err := s.messageCache.MarkPublished(m); err != nil {
return err
}
return nil return nil
} }
@ -1290,8 +1281,6 @@ func extractUserPass(r *http.Request) (username string, password string, ok bool
// visitor creates or retrieves a rate.Limiter for the given visitor. // visitor creates or retrieves a rate.Limiter for the given visitor.
// This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT). // This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT).
func (s *Server) visitor(r *http.Request) *visitor { func (s *Server) visitor(r *http.Request) *visitor {
s.mu.Lock()
defer s.mu.Unlock()
remoteAddr := r.RemoteAddr remoteAddr := r.RemoteAddr
ip, _, err := net.SplitHostPort(remoteAddr) ip, _, err := net.SplitHostPort(remoteAddr)
if err != nil { if err != nil {
@ -1300,6 +1289,12 @@ func (s *Server) visitor(r *http.Request) *visitor {
if s.config.BehindProxy && r.Header.Get("X-Forwarded-For") != "" { if s.config.BehindProxy && r.Header.Get("X-Forwarded-For") != "" {
ip = r.Header.Get("X-Forwarded-For") ip = r.Header.Get("X-Forwarded-For")
} }
return s.visitorFromIP(ip)
}
func (s *Server) visitorFromIP(ip string) *visitor {
s.mu.Lock()
defer s.mu.Unlock()
v, exists := s.visitors[ip] v, exists := s.visitors[ip]
if !exists { if !exists {
s.visitors[ip] = newVisitor(s.config, s.messageCache, ip) s.visitors[ip] = newVisitor(s.config, s.messageCache, ip)

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"log"
"strings" "strings"
firebase "firebase.google.com/go/v4" firebase "firebase.google.com/go/v4"
@ -26,12 +27,20 @@ func createFirebaseSubscriber(credentialsFile string, auther auth.Auther) (subsc
if err != nil { if err != nil {
return nil, err return nil, err
} }
return func(m *message) error { return func(v *visitor, m *message) error {
if err := v.FirebaseAllowed(); err != nil {
return errHTTPTooManyRequestsFirebaseQuotaReached
}
fbm, err := toFirebaseMessage(m, auther) fbm, err := toFirebaseMessage(m, auther)
if err != nil { if err != nil {
return err return err
} }
_, err = msg.Send(context.Background(), fbm) _, err = msg.Send(context.Background(), fbm)
if err != nil && messaging.IsQuotaExceeded(err) {
log.Printf("[%s] FB quota exceeded when trying to publish to topic %s, temporarily denying FB access", v.ip, m.Topic)
v.FirebaseTemporarilyDeny()
return errHTTPTooManyRequestsFirebaseQuotaReached
}
return err return err
}, nil }, nil
} }

View file

@ -469,7 +469,8 @@ func TestServer_PublishFirebase(t *testing.T) {
require.NotEmpty(t, msg.ID) require.NotEmpty(t, msg.ID)
// Keepalive message // Keepalive message
require.Nil(t, s.firebase(newKeepaliveMessage(firebaseControlTopic))) v := newVisitor(s.config, s.messageCache, "1.2.3.4")
require.Nil(t, s.firebase(v, newKeepaliveMessage(firebaseControlTopic)))
time.Sleep(500 * time.Millisecond) // Time for sends time.Sleep(500 * time.Millisecond) // Time for sends
} }

View file

@ -3,10 +3,13 @@ package server
import ( import (
"bytes" "bytes"
"errors" "errors"
"fmt"
"github.com/emersion/go-smtp" "github.com/emersion/go-smtp"
"io" "io"
"mime" "mime"
"mime/multipart" "mime/multipart"
"net/http"
"net/http/httptest"
"net/mail" "net/mail"
"strings" "strings"
"sync" "sync"
@ -23,25 +26,25 @@ var (
// smtpBackend implements SMTP server methods. // smtpBackend implements SMTP server methods.
type smtpBackend struct { type smtpBackend struct {
config *Config config *Config
sub subscriber handler func(http.ResponseWriter, *http.Request)
success int64 success int64
failure int64 failure int64
mu sync.Mutex mu sync.Mutex
} }
func newMailBackend(conf *Config, sub subscriber) *smtpBackend { func newMailBackend(conf *Config, handler func(http.ResponseWriter, *http.Request)) *smtpBackend {
return &smtpBackend{ return &smtpBackend{
config: conf, config: conf,
sub: sub, handler: handler,
} }
} }
func (b *smtpBackend) Login(state *smtp.ConnectionState, username, password string) (smtp.Session, error) { func (b *smtpBackend) Login(state *smtp.ConnectionState, username, password string) (smtp.Session, error) {
return &smtpSession{backend: b}, nil return &smtpSession{backend: b, remoteAddr: state.RemoteAddr.String()}, nil
} }
func (b *smtpBackend) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) { func (b *smtpBackend) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) {
return &smtpSession{backend: b}, nil return &smtpSession{backend: b, remoteAddr: state.RemoteAddr.String()}, nil
} }
func (b *smtpBackend) Counts() (success int64, failure int64) { func (b *smtpBackend) Counts() (success int64, failure int64) {
@ -52,9 +55,10 @@ func (b *smtpBackend) Counts() (success int64, failure int64) {
// smtpSession is returned after EHLO. // smtpSession is returned after EHLO.
type smtpSession struct { type smtpSession struct {
backend *smtpBackend backend *smtpBackend
topic string remoteAddr string
mu sync.Mutex topic string
mu sync.Mutex
} }
func (s *smtpSession) AuthPlain(username, password string) error { func (s *smtpSession) AuthPlain(username, password string) error {
@ -128,7 +132,7 @@ func (s *smtpSession) Data(r io.Reader) error {
m.Message = m.Title // Flip them, this makes more sense m.Message = m.Title // Flip them, this makes more sense
m.Title = "" m.Title = ""
} }
if err := s.backend.sub(m); err != nil { if err := s.publishMessage(m); err != nil {
return err return err
} }
s.backend.mu.Lock() s.backend.mu.Lock()
@ -138,6 +142,24 @@ func (s *smtpSession) Data(r io.Reader) error {
}) })
} }
func (s *smtpSession) publishMessage(m *message) error {
url := fmt.Sprintf("%s/%s", s.backend.config.BaseURL, m.Topic)
req, err := http.NewRequest("PUT", url, strings.NewReader(m.Message))
req.RemoteAddr = s.remoteAddr // rate limiting!!
if err != nil {
return err
}
if m.Title != "" {
req.Header.Set("Title", m.Title)
}
rr := httptest.NewRecorder()
s.backend.handler(rr, req)
if rr.Code != http.StatusOK {
return errors.New("error: " + rr.Body.String())
}
return nil
}
func (s *smtpSession) Reset() { func (s *smtpSession) Reset() {
s.mu.Lock() s.mu.Lock()
s.topic = "" s.topic = ""

View file

@ -3,6 +3,9 @@ package server
import ( import (
"github.com/emersion/go-smtp" "github.com/emersion/go-smtp"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"io"
"net"
"net/http"
"strings" "strings"
"testing" "testing"
) )
@ -27,13 +30,12 @@ Content-Type: text/html; charset="UTF-8"
<div dir="ltr">what&#39;s up<br clear="all"><div><br></div></div> <div dir="ltr">what&#39;s up<br clear="all"><div><br></div></div>
--000000000000f3320b05d42915c9--` --000000000000f3320b05d42915c9--`
_, backend := newTestBackend(t, func(m *message) error { _, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "mytopic", m.Topic) require.Equal(t, "/mytopic", r.URL.Path)
require.Equal(t, "and one more", m.Title) require.Equal(t, "and one more", r.Header.Get("Title"))
require.Equal(t, "what's up", m.Message) require.Equal(t, "what's up", readAll(t, r.Body))
return nil
}) })
session, _ := backend.AnonymousLogin(nil) session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{})) require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
require.Nil(t, session.Rcpt("ntfy-mytopic@ntfy.sh")) require.Nil(t, session.Rcpt("ntfy-mytopic@ntfy.sh"))
require.Nil(t, session.Data(strings.NewReader(email))) require.Nil(t, session.Data(strings.NewReader(email)))
@ -59,13 +61,12 @@ Content-Type: text/html; charset="UTF-8"
<div dir="ltr"><br></div> <div dir="ltr"><br></div>
--000000000000bcf4a405d429f8d4--` --000000000000bcf4a405d429f8d4--`
_, backend := newTestBackend(t, func(m *message) error { _, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "emailtest", m.Topic) require.Equal(t, "/emailtest", r.URL.Path)
require.Equal(t, "", m.Title) // We flipped message and body require.Equal(t, "", r.Header.Get("Title")) // We flipped message and body
require.Equal(t, "This email has a subject but no body", m.Message) require.Equal(t, "This email has a subject but no body", readAll(t, r.Body))
return nil
}) })
session, _ := backend.AnonymousLogin(nil) session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{})) require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
require.Nil(t, session.Rcpt("ntfy-emailtest@ntfy.sh")) require.Nil(t, session.Rcpt("ntfy-emailtest@ntfy.sh"))
require.Nil(t, session.Data(strings.NewReader(email))) require.Nil(t, session.Data(strings.NewReader(email)))
@ -81,14 +82,13 @@ Content-Type: text/plain; charset="UTF-8"
what's up what's up
` `
conf, backend := newTestBackend(t, func(m *message) error { conf, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "mytopic", m.Topic) require.Equal(t, "/mytopic", r.URL.Path)
require.Equal(t, "and one more", m.Title) require.Equal(t, "and one more", r.Header.Get("Title"))
require.Equal(t, "what's up", m.Message) require.Equal(t, "what's up", readAll(t, r.Body))
return nil
}) })
conf.SMTPServerAddrPrefix = "" conf.SMTPServerAddrPrefix = ""
session, _ := backend.AnonymousLogin(nil) session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{})) require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
require.Nil(t, session.Rcpt("mytopic@ntfy.sh")) require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
require.Nil(t, session.Data(strings.NewReader(email))) require.Nil(t, session.Data(strings.NewReader(email)))
@ -99,14 +99,13 @@ func TestSmtpBackend_Plaintext_No_ContentType(t *testing.T) {
what's up what's up
` `
conf, backend := newTestBackend(t, func(m *message) error { conf, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "mytopic", m.Topic) require.Equal(t, "/mytopic", r.URL.Path)
require.Equal(t, "Very short mail", m.Title) require.Equal(t, "Very short mail", r.Header.Get("Title"))
require.Equal(t, "what's up", m.Message) require.Equal(t, "what's up", readAll(t, r.Body))
return nil
}) })
conf.SMTPServerAddrPrefix = "" conf.SMTPServerAddrPrefix = ""
session, _ := backend.AnonymousLogin(nil) session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{})) require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
require.Nil(t, session.Rcpt("mytopic@ntfy.sh")) require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
require.Nil(t, session.Data(strings.NewReader(email))) require.Nil(t, session.Data(strings.NewReader(email)))
@ -121,11 +120,10 @@ Content-Type: text/plain; charset="UTF-8"
what's up what's up
` `
_, backend := newTestBackend(t, func(m *message) error { _, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "Three santas 🎅🎅🎅", m.Title) require.Equal(t, "Three santas 🎅🎅🎅", r.Header.Get("Title"))
return nil
}) })
session, _ := backend.AnonymousLogin(nil) session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{})) require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
require.Nil(t, session.Rcpt("ntfy-mytopic@ntfy.sh")) require.Nil(t, session.Rcpt("ntfy-mytopic@ntfy.sh"))
require.Nil(t, session.Data(strings.NewReader(email))) require.Nil(t, session.Data(strings.NewReader(email)))
@ -140,7 +138,7 @@ To: mytopic@ntfy.sh
Content-Type: text/plain; charset="UTF-8" Content-Type: text/plain; charset="UTF-8"
you know this is a string. you know this is a string.
it's a long string. it's a long string.
it's supposed to be longer than the max message length it's supposed to be longer than the max message length
which is 4096 bytes, which is 4096 bytes,
it used to be 512 bytes, but I increased that for the UnifiedPush support it used to be 512 bytes, but I increased that for the UnifiedPush support
@ -204,9 +202,9 @@ BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
that should do it that should do it
` `
conf, backend := newTestBackend(t, func(m *message) error { conf, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
expected := `you know this is a string. expected := `you know this is a string.
it's a long string. it's a long string.
it's supposed to be longer than the max message length it's supposed to be longer than the max message length
which is 4096 bytes, which is 4096 bytes,
it used to be 512 bytes, but I increased that for the UnifiedPush support it used to be 512 bytes, but I increased that for the UnifiedPush support
@ -266,13 +264,12 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
...................................................................... ......................................................................
...................................................................... ......................................................................
and with BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB and with BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
BBBBBBBBBBBBBBBBBBBBBBBB` BBBBBBBBBBBBBBBBBBBBBBBBB`
require.Equal(t, 4096, len(expected)) // Sanity check require.Equal(t, 4096, len(expected)) // Sanity check
require.Equal(t, expected, m.Message) require.Equal(t, expected, readAll(t, r.Body))
return nil
}) })
conf.SMTPServerAddrPrefix = "" conf.SMTPServerAddrPrefix = ""
session, _ := backend.AnonymousLogin(nil) session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{})) require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
require.Nil(t, session.Rcpt("mytopic@ntfy.sh")) require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
require.Nil(t, session.Data(strings.NewReader(email))) require.Nil(t, session.Data(strings.NewReader(email)))
@ -288,21 +285,41 @@ Content-Type: text/SOMETHINGELSE
what's up what's up
` `
conf, backend := newTestBackend(t, func(m *message) error { conf, backend := newTestBackend(t, func(http.ResponseWriter, *http.Request) {
return nil // Nothing.
}) })
conf.SMTPServerAddrPrefix = "" conf.SMTPServerAddrPrefix = ""
session, _ := backend.Login(nil, "user", "pass") session, _ := backend.Login(fakeConnState(t, "1.2.3.4"), "user", "pass")
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{})) require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
require.Nil(t, session.Rcpt("mytopic@ntfy.sh")) require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
require.Equal(t, errUnsupportedContentType, session.Data(strings.NewReader(email))) require.Equal(t, errUnsupportedContentType, session.Data(strings.NewReader(email)))
} }
func newTestBackend(t *testing.T, sub subscriber) (*Config, *smtpBackend) { func newTestBackend(t *testing.T, handler func(http.ResponseWriter, *http.Request)) (*Config, *smtpBackend) {
conf := newTestConfig(t) conf := newTestConfig(t)
conf.SMTPServerListen = ":25" conf.SMTPServerListen = ":25"
conf.SMTPServerDomain = "ntfy.sh" conf.SMTPServerDomain = "ntfy.sh"
conf.SMTPServerAddrPrefix = "ntfy-" conf.SMTPServerAddrPrefix = "ntfy-"
backend := newMailBackend(conf, sub) backend := newMailBackend(conf, handler)
return conf, backend return conf, backend
} }
func readAll(t *testing.T, rc io.ReadCloser) string {
b, err := io.ReadAll(rc)
if err != nil {
t.Fatal(err)
}
return string(b)
}
func fakeConnState(t *testing.T, remoteAddr string) *smtp.ConnectionState {
ip, err := net.ResolveIPAddr("ip", remoteAddr)
if err != nil {
t.Fatal(err)
}
return &smtp.ConnectionState{
Hostname: "myhostname",
LocalAddr: ip,
RemoteAddr: ip,
}
}

View file

@ -15,7 +15,7 @@ type topic struct {
} }
// subscriber is a function that is called for every new message on a topic // subscriber is a function that is called for every new message on a topic
type subscriber func(msg *message) error type subscriber func(v *visitor, msg *message) error
// newTopic creates a new topic // newTopic creates a new topic
func newTopic(id string) *topic { func newTopic(id string) *topic {
@ -42,12 +42,12 @@ func (t *topic) Unsubscribe(id int) {
} }
// Publish asynchronously publishes to all subscribers // Publish asynchronously publishes to all subscribers
func (t *topic) Publish(m *message) error { func (t *topic) Publish(v *visitor, m *message) error {
go func() { go func() {
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock() defer t.mu.Unlock()
for _, s := range t.subscribers { for _, s := range t.subscribers {
if err := s(m); err != nil { if err := s(v, m); err != nil {
log.Printf("error publishing message to subscriber") log.Printf("error publishing message to subscriber")
} }
} }

View file

@ -28,6 +28,7 @@ type visitor struct {
emails *rate.Limiter emails *rate.Limiter
subscriptions util.Limiter subscriptions util.Limiter
bandwidth util.Limiter bandwidth util.Limiter
firebase time.Time // Next allowed Firebase message
seen time.Time seen time.Time
mu sync.Mutex mu sync.Mutex
} }
@ -48,14 +49,11 @@ func newVisitor(conf *Config, messageCache *messageCache, ip string) *visitor {
emails: rate.NewLimiter(rate.Every(conf.VisitorEmailLimitReplenish), conf.VisitorEmailLimitBurst), emails: rate.NewLimiter(rate.Every(conf.VisitorEmailLimitReplenish), conf.VisitorEmailLimitBurst),
subscriptions: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)), subscriptions: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)),
bandwidth: util.NewBytesLimiter(conf.VisitorAttachmentDailyBandwidthLimit, 24*time.Hour), bandwidth: util.NewBytesLimiter(conf.VisitorAttachmentDailyBandwidthLimit, 24*time.Hour),
firebase: time.Unix(0, 0),
seen: time.Now(), seen: time.Now(),
} }
} }
func (v *visitor) IP() string {
return v.ip
}
func (v *visitor) RequestAllowed() error { func (v *visitor) RequestAllowed() error {
if !v.requests.Allow() { if !v.requests.Allow() {
return errVisitorLimitReached return errVisitorLimitReached
@ -63,6 +61,21 @@ func (v *visitor) RequestAllowed() error {
return nil return nil
} }
func (v *visitor) FirebaseAllowed() error {
v.mu.Lock()
defer v.mu.Unlock()
if time.Now().Before(v.firebase) {
return errVisitorLimitReached
}
return nil
}
func (v *visitor) FirebaseTemporarilyDeny() {
v.mu.Lock()
defer v.mu.Unlock()
v.firebase = time.Now().Add(v.config.FirebaseQuotaLimitPenaltyDuration)
}
func (v *visitor) EmailAllowed() error { func (v *visitor) EmailAllowed() error {
if !v.emails.Allow() { if !v.emails.Allow() {
return errVisitorLimitReached return errVisitorLimitReached