mirror of
https://github.com/binwiederhier/ntfy.git
synced 2024-11-22 19:33:27 +01:00
Firebase quota limit
This commit is contained in:
parent
8a81c8e95b
commit
8283b6be97
9 changed files with 180 additions and 119 deletions
|
@ -15,6 +15,7 @@ const (
|
|||
DefaultMaxDelay = 3 * 24 * time.Hour
|
||||
DefaultFirebaseKeepaliveInterval = 3 * time.Hour // ~control topic (Android), not too frequently to save battery
|
||||
DefaultFirebasePollInterval = 20 * time.Minute // ~poll topic (iOS), max. 2-3 times per hour (see docs)
|
||||
DefaultFirebaseQuotaLimitPenaltyDuration = 10 * time.Minute
|
||||
)
|
||||
|
||||
// Defines all global and per-visitor limits
|
||||
|
@ -69,6 +70,7 @@ type Config struct {
|
|||
AtSenderInterval time.Duration
|
||||
FirebaseKeepaliveInterval time.Duration
|
||||
FirebasePollInterval time.Duration
|
||||
FirebaseQuotaLimitPenaltyDuration time.Duration
|
||||
UpstreamBaseURL string
|
||||
SMTPSenderAddr string
|
||||
SMTPSenderUser string
|
||||
|
@ -121,6 +123,7 @@ func NewConfig() *Config {
|
|||
AtSenderInterval: DefaultAtSenderInterval,
|
||||
FirebaseKeepaliveInterval: DefaultFirebaseKeepaliveInterval,
|
||||
FirebasePollInterval: DefaultFirebasePollInterval,
|
||||
FirebaseQuotaLimitPenaltyDuration: DefaultFirebaseQuotaLimitPenaltyDuration,
|
||||
TotalTopicLimit: DefaultTotalTopicLimit,
|
||||
VisitorSubscriptionLimit: DefaultVisitorSubscriptionLimit,
|
||||
VisitorAttachmentTotalSizeLimit: DefaultVisitorAttachmentTotalSizeLimit,
|
||||
|
|
|
@ -59,6 +59,7 @@ var (
|
|||
errHTTPTooManyRequestsLimitSubscriptions = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
|
||||
errHTTPTooManyRequestsLimitTotalTopics = &errHTTP{42904, http.StatusTooManyRequests, "limit reached: the total number of topics on the server has been reached, please contact the admin", "https://ntfy.sh/docs/publish/#limitations"}
|
||||
errHTTPTooManyRequestsAttachmentBandwidthLimit = &errHTTP{42905, http.StatusTooManyRequests, "too many requests: daily bandwidth limit reached", "https://ntfy.sh/docs/publish/#limitations"}
|
||||
errHTTPTooManyRequestsFirebaseQuotaReached = &errHTTP{42906, http.StatusTooManyRequests, "too many requests: Firebase quota for topic reached", "https://ntfy.sh/docs/publish/#limitations"}
|
||||
errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""}
|
||||
errHTTPInternalErrorInvalidFilePath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid file path", ""}
|
||||
)
|
||||
|
|
|
@ -7,13 +7,11 @@ import (
|
|||
"embed"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
|
@ -221,7 +219,7 @@ func (s *Server) Run() error {
|
|||
}
|
||||
s.mu.Unlock()
|
||||
go s.runManager()
|
||||
go s.runAtSender()
|
||||
go s.runDelayedSender()
|
||||
go s.runFirebaseKeepaliver()
|
||||
|
||||
return <-errChan
|
||||
|
@ -435,7 +433,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
|
|||
}
|
||||
delayed := m.Time > time.Now().Unix()
|
||||
if !delayed {
|
||||
if err := t.Publish(m); err != nil {
|
||||
if err := t.Publish(v, m); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -465,7 +463,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
|
|||
}
|
||||
|
||||
func (s *Server) sendToFirebase(v *visitor, m *message) {
|
||||
if err := s.firebase(m); err != nil {
|
||||
if err := s.firebase(v, m); err != nil {
|
||||
log.Printf("[%s] FB - Unable to publish to Firebase: %v", v.ip, err.Error())
|
||||
}
|
||||
}
|
||||
|
@ -731,7 +729,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
|||
return err
|
||||
}
|
||||
var wlock sync.Mutex
|
||||
sub := func(msg *message) error {
|
||||
sub := func(v *visitor, msg *message) error {
|
||||
if !filters.Pass(msg) {
|
||||
return nil
|
||||
}
|
||||
|
@ -752,7 +750,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
|||
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
|
||||
w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
|
||||
if poll {
|
||||
return s.sendOldMessages(topics, since, scheduled, sub)
|
||||
return s.sendOldMessages(topics, since, scheduled, v, sub)
|
||||
}
|
||||
subscriberIDs := make([]int, 0)
|
||||
for _, t := range topics {
|
||||
|
@ -763,10 +761,10 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
|||
topics[i].Unsubscribe(subscriberID) // Order!
|
||||
}
|
||||
}()
|
||||
if err := sub(newOpenMessage(topicsStr)); err != nil { // Send out open message
|
||||
if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
|
||||
return err
|
||||
}
|
||||
if err := s.sendOldMessages(topics, since, scheduled, sub); err != nil {
|
||||
if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
|
||||
return err
|
||||
}
|
||||
for {
|
||||
|
@ -775,7 +773,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
|||
return nil
|
||||
case <-time.After(s.config.KeepaliveInterval):
|
||||
v.Keepalive()
|
||||
if err := sub(newKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message
|
||||
if err := sub(v, newKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -849,7 +847,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
|||
}
|
||||
}
|
||||
})
|
||||
sub := func(msg *message) error {
|
||||
sub := func(v *visitor, msg *message) error {
|
||||
if !filters.Pass(msg) {
|
||||
return nil
|
||||
}
|
||||
|
@ -862,7 +860,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
|||
}
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
|
||||
if poll {
|
||||
return s.sendOldMessages(topics, since, scheduled, sub)
|
||||
return s.sendOldMessages(topics, since, scheduled, v, sub)
|
||||
}
|
||||
subscriberIDs := make([]int, 0)
|
||||
for _, t := range topics {
|
||||
|
@ -873,10 +871,10 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
|||
topics[i].Unsubscribe(subscriberID) // Order!
|
||||
}
|
||||
}()
|
||||
if err := sub(newOpenMessage(topicsStr)); err != nil { // Send out open message
|
||||
if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
|
||||
return err
|
||||
}
|
||||
if err := s.sendOldMessages(topics, since, scheduled, sub); err != nil {
|
||||
if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
|
||||
return err
|
||||
}
|
||||
err = g.Wait()
|
||||
|
@ -900,7 +898,7 @@ func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, schedu
|
|||
return
|
||||
}
|
||||
|
||||
func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, sub subscriber) error {
|
||||
func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, v *visitor, sub subscriber) error {
|
||||
if since.IsNone() {
|
||||
return nil
|
||||
}
|
||||
|
@ -910,7 +908,7 @@ func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled b
|
|||
return err
|
||||
}
|
||||
for _, m := range messages {
|
||||
if err := sub(m); err != nil {
|
||||
if err := sub(v, m); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -1057,23 +1055,7 @@ func (s *Server) updateStatsAndPrune() {
|
|||
}
|
||||
|
||||
func (s *Server) runSMTPServer() error {
|
||||
sub := func(m *message) error {
|
||||
url := fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic)
|
||||
req, err := http.NewRequest("PUT", url, strings.NewReader(m.Message))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if m.Title != "" {
|
||||
req.Header.Set("Title", m.Title)
|
||||
}
|
||||
rr := httptest.NewRecorder()
|
||||
s.handle(rr, req)
|
||||
if rr.Code != http.StatusOK {
|
||||
return errors.New("error: " + rr.Body.String())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
s.smtpBackend = newMailBackend(s.config, sub)
|
||||
s.smtpBackend = newMailBackend(s.config, s.handle)
|
||||
s.smtpServer = smtp.NewServer(s.smtpBackend)
|
||||
s.smtpServer.Addr = s.config.SMTPServerListen
|
||||
s.smtpServer.Domain = s.config.SMTPServerDomain
|
||||
|
@ -1096,7 +1078,7 @@ func (s *Server) runManager() {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Server) runAtSender() {
|
||||
func (s *Server) runDelayedSender() {
|
||||
for {
|
||||
select {
|
||||
case <-time.After(s.config.AtSenderInterval):
|
||||
|
@ -1113,14 +1095,15 @@ func (s *Server) runFirebaseKeepaliver() {
|
|||
if s.firebase == nil {
|
||||
return
|
||||
}
|
||||
v := newVisitor(s.config, s.messageCache, "0.0.0.0")
|
||||
for {
|
||||
select {
|
||||
case <-time.After(s.config.FirebaseKeepaliveInterval):
|
||||
if err := s.firebase(newKeepaliveMessage(firebaseControlTopic)); err != nil {
|
||||
if err := s.firebase(v, newKeepaliveMessage(firebaseControlTopic)); err != nil {
|
||||
log.Printf("error sending Firebase keepalive message to %s: %s", firebaseControlTopic, err.Error())
|
||||
}
|
||||
case <-time.After(s.config.FirebasePollInterval):
|
||||
if err := s.firebase(newKeepaliveMessage(firebasePollTopic)); err != nil {
|
||||
if err := s.firebase(v, newKeepaliveMessage(firebasePollTopic)); err != nil {
|
||||
log.Printf("error sending Firebase keepalive message to %s: %s", firebasePollTopic, err.Error())
|
||||
}
|
||||
case <-s.closeChan:
|
||||
|
@ -1130,28 +1113,36 @@ func (s *Server) runFirebaseKeepaliver() {
|
|||
}
|
||||
|
||||
func (s *Server) sendDelayedMessages() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
messages, err := s.messageCache.MessagesDue()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, m := range messages {
|
||||
v := s.visitorFromIP("0.0.0.0") // FIXME: get message owner!!
|
||||
if err := s.sendDelayedMessage(v, m); err != nil {
|
||||
log.Printf("error sending delayed message: %s", err.Error())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published
|
||||
if ok {
|
||||
if err := t.Publish(m); err != nil {
|
||||
log.Printf("unable to publish message %s to topic %s: %v", m.ID, m.Topic, err.Error())
|
||||
if err := t.Publish(v, m); err != nil {
|
||||
return fmt.Errorf("unable to publish message %s to topic %s: %v", m.ID, m.Topic, err.Error())
|
||||
}
|
||||
}
|
||||
if s.firebase != nil { // Firebase subscribers may not show up in topics map
|
||||
if err := s.firebase(m); err != nil {
|
||||
log.Printf("unable to publish to Firebase: %v", err.Error())
|
||||
if err := s.firebase(v, m); err != nil {
|
||||
return fmt.Errorf("unable to publish to Firebase: %v", err.Error())
|
||||
}
|
||||
}
|
||||
if err := s.messageCache.MarkPublished(m); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -1290,8 +1281,6 @@ func extractUserPass(r *http.Request) (username string, password string, ok bool
|
|||
// visitor creates or retrieves a rate.Limiter for the given visitor.
|
||||
// This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT).
|
||||
func (s *Server) visitor(r *http.Request) *visitor {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
remoteAddr := r.RemoteAddr
|
||||
ip, _, err := net.SplitHostPort(remoteAddr)
|
||||
if err != nil {
|
||||
|
@ -1300,6 +1289,12 @@ func (s *Server) visitor(r *http.Request) *visitor {
|
|||
if s.config.BehindProxy && r.Header.Get("X-Forwarded-For") != "" {
|
||||
ip = r.Header.Get("X-Forwarded-For")
|
||||
}
|
||||
return s.visitorFromIP(ip)
|
||||
}
|
||||
|
||||
func (s *Server) visitorFromIP(ip string) *visitor {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
v, exists := s.visitors[ip]
|
||||
if !exists {
|
||||
s.visitors[ip] = newVisitor(s.config, s.messageCache, ip)
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
firebase "firebase.google.com/go/v4"
|
||||
|
@ -26,12 +27,20 @@ func createFirebaseSubscriber(credentialsFile string, auther auth.Auther) (subsc
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return func(m *message) error {
|
||||
return func(v *visitor, m *message) error {
|
||||
if err := v.FirebaseAllowed(); err != nil {
|
||||
return errHTTPTooManyRequestsFirebaseQuotaReached
|
||||
}
|
||||
fbm, err := toFirebaseMessage(m, auther)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = msg.Send(context.Background(), fbm)
|
||||
if err != nil && messaging.IsQuotaExceeded(err) {
|
||||
log.Printf("[%s] FB quota exceeded when trying to publish to topic %s, temporarily denying FB access", v.ip, m.Topic)
|
||||
v.FirebaseTemporarilyDeny()
|
||||
return errHTTPTooManyRequestsFirebaseQuotaReached
|
||||
}
|
||||
return err
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@ -469,7 +469,8 @@ func TestServer_PublishFirebase(t *testing.T) {
|
|||
require.NotEmpty(t, msg.ID)
|
||||
|
||||
// Keepalive message
|
||||
require.Nil(t, s.firebase(newKeepaliveMessage(firebaseControlTopic)))
|
||||
v := newVisitor(s.config, s.messageCache, "1.2.3.4")
|
||||
require.Nil(t, s.firebase(v, newKeepaliveMessage(firebaseControlTopic)))
|
||||
|
||||
time.Sleep(500 * time.Millisecond) // Time for sends
|
||||
}
|
||||
|
|
|
@ -3,10 +3,13 @@ package server
|
|||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/emersion/go-smtp"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/mail"
|
||||
"strings"
|
||||
"sync"
|
||||
|
@ -23,25 +26,25 @@ var (
|
|||
// smtpBackend implements SMTP server methods.
|
||||
type smtpBackend struct {
|
||||
config *Config
|
||||
sub subscriber
|
||||
handler func(http.ResponseWriter, *http.Request)
|
||||
success int64
|
||||
failure int64
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newMailBackend(conf *Config, sub subscriber) *smtpBackend {
|
||||
func newMailBackend(conf *Config, handler func(http.ResponseWriter, *http.Request)) *smtpBackend {
|
||||
return &smtpBackend{
|
||||
config: conf,
|
||||
sub: sub,
|
||||
handler: handler,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *smtpBackend) Login(state *smtp.ConnectionState, username, password string) (smtp.Session, error) {
|
||||
return &smtpSession{backend: b}, nil
|
||||
return &smtpSession{backend: b, remoteAddr: state.RemoteAddr.String()}, nil
|
||||
}
|
||||
|
||||
func (b *smtpBackend) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) {
|
||||
return &smtpSession{backend: b}, nil
|
||||
return &smtpSession{backend: b, remoteAddr: state.RemoteAddr.String()}, nil
|
||||
}
|
||||
|
||||
func (b *smtpBackend) Counts() (success int64, failure int64) {
|
||||
|
@ -53,6 +56,7 @@ func (b *smtpBackend) Counts() (success int64, failure int64) {
|
|||
// smtpSession is returned after EHLO.
|
||||
type smtpSession struct {
|
||||
backend *smtpBackend
|
||||
remoteAddr string
|
||||
topic string
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
@ -128,7 +132,7 @@ func (s *smtpSession) Data(r io.Reader) error {
|
|||
m.Message = m.Title // Flip them, this makes more sense
|
||||
m.Title = ""
|
||||
}
|
||||
if err := s.backend.sub(m); err != nil {
|
||||
if err := s.publishMessage(m); err != nil {
|
||||
return err
|
||||
}
|
||||
s.backend.mu.Lock()
|
||||
|
@ -138,6 +142,24 @@ func (s *smtpSession) Data(r io.Reader) error {
|
|||
})
|
||||
}
|
||||
|
||||
func (s *smtpSession) publishMessage(m *message) error {
|
||||
url := fmt.Sprintf("%s/%s", s.backend.config.BaseURL, m.Topic)
|
||||
req, err := http.NewRequest("PUT", url, strings.NewReader(m.Message))
|
||||
req.RemoteAddr = s.remoteAddr // rate limiting!!
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if m.Title != "" {
|
||||
req.Header.Set("Title", m.Title)
|
||||
}
|
||||
rr := httptest.NewRecorder()
|
||||
s.backend.handler(rr, req)
|
||||
if rr.Code != http.StatusOK {
|
||||
return errors.New("error: " + rr.Body.String())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *smtpSession) Reset() {
|
||||
s.mu.Lock()
|
||||
s.topic = ""
|
||||
|
|
|
@ -3,6 +3,9 @@ package server
|
|||
import (
|
||||
"github.com/emersion/go-smtp"
|
||||
"github.com/stretchr/testify/require"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
@ -27,13 +30,12 @@ Content-Type: text/html; charset="UTF-8"
|
|||
<div dir="ltr">what's up<br clear="all"><div><br></div></div>
|
||||
|
||||
--000000000000f3320b05d42915c9--`
|
||||
_, backend := newTestBackend(t, func(m *message) error {
|
||||
require.Equal(t, "mytopic", m.Topic)
|
||||
require.Equal(t, "and one more", m.Title)
|
||||
require.Equal(t, "what's up", m.Message)
|
||||
return nil
|
||||
_, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/mytopic", r.URL.Path)
|
||||
require.Equal(t, "and one more", r.Header.Get("Title"))
|
||||
require.Equal(t, "what's up", readAll(t, r.Body))
|
||||
})
|
||||
session, _ := backend.AnonymousLogin(nil)
|
||||
session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
|
||||
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
|
||||
require.Nil(t, session.Rcpt("ntfy-mytopic@ntfy.sh"))
|
||||
require.Nil(t, session.Data(strings.NewReader(email)))
|
||||
|
@ -59,13 +61,12 @@ Content-Type: text/html; charset="UTF-8"
|
|||
<div dir="ltr"><br></div>
|
||||
|
||||
--000000000000bcf4a405d429f8d4--`
|
||||
_, backend := newTestBackend(t, func(m *message) error {
|
||||
require.Equal(t, "emailtest", m.Topic)
|
||||
require.Equal(t, "", m.Title) // We flipped message and body
|
||||
require.Equal(t, "This email has a subject but no body", m.Message)
|
||||
return nil
|
||||
_, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/emailtest", r.URL.Path)
|
||||
require.Equal(t, "", r.Header.Get("Title")) // We flipped message and body
|
||||
require.Equal(t, "This email has a subject but no body", readAll(t, r.Body))
|
||||
})
|
||||
session, _ := backend.AnonymousLogin(nil)
|
||||
session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
|
||||
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
|
||||
require.Nil(t, session.Rcpt("ntfy-emailtest@ntfy.sh"))
|
||||
require.Nil(t, session.Data(strings.NewReader(email)))
|
||||
|
@ -81,14 +82,13 @@ Content-Type: text/plain; charset="UTF-8"
|
|||
|
||||
what's up
|
||||
`
|
||||
conf, backend := newTestBackend(t, func(m *message) error {
|
||||
require.Equal(t, "mytopic", m.Topic)
|
||||
require.Equal(t, "and one more", m.Title)
|
||||
require.Equal(t, "what's up", m.Message)
|
||||
return nil
|
||||
conf, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/mytopic", r.URL.Path)
|
||||
require.Equal(t, "and one more", r.Header.Get("Title"))
|
||||
require.Equal(t, "what's up", readAll(t, r.Body))
|
||||
})
|
||||
conf.SMTPServerAddrPrefix = ""
|
||||
session, _ := backend.AnonymousLogin(nil)
|
||||
session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
|
||||
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
|
||||
require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
|
||||
require.Nil(t, session.Data(strings.NewReader(email)))
|
||||
|
@ -99,14 +99,13 @@ func TestSmtpBackend_Plaintext_No_ContentType(t *testing.T) {
|
|||
|
||||
what's up
|
||||
`
|
||||
conf, backend := newTestBackend(t, func(m *message) error {
|
||||
require.Equal(t, "mytopic", m.Topic)
|
||||
require.Equal(t, "Very short mail", m.Title)
|
||||
require.Equal(t, "what's up", m.Message)
|
||||
return nil
|
||||
conf, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/mytopic", r.URL.Path)
|
||||
require.Equal(t, "Very short mail", r.Header.Get("Title"))
|
||||
require.Equal(t, "what's up", readAll(t, r.Body))
|
||||
})
|
||||
conf.SMTPServerAddrPrefix = ""
|
||||
session, _ := backend.AnonymousLogin(nil)
|
||||
session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
|
||||
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
|
||||
require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
|
||||
require.Nil(t, session.Data(strings.NewReader(email)))
|
||||
|
@ -121,11 +120,10 @@ Content-Type: text/plain; charset="UTF-8"
|
|||
|
||||
what's up
|
||||
`
|
||||
_, backend := newTestBackend(t, func(m *message) error {
|
||||
require.Equal(t, "Three santas 🎅🎅🎅", m.Title)
|
||||
return nil
|
||||
_, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "Three santas 🎅🎅🎅", r.Header.Get("Title"))
|
||||
})
|
||||
session, _ := backend.AnonymousLogin(nil)
|
||||
session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
|
||||
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
|
||||
require.Nil(t, session.Rcpt("ntfy-mytopic@ntfy.sh"))
|
||||
require.Nil(t, session.Data(strings.NewReader(email)))
|
||||
|
@ -204,7 +202,7 @@ BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
|
|||
BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
|
||||
that should do it
|
||||
`
|
||||
conf, backend := newTestBackend(t, func(m *message) error {
|
||||
conf, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
expected := `you know this is a string.
|
||||
it's a long string.
|
||||
it's supposed to be longer than the max message length
|
||||
|
@ -266,13 +264,12 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
|||
......................................................................
|
||||
......................................................................
|
||||
and with BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
|
||||
BBBBBBBBBBBBBBBBBBBBBBBB`
|
||||
BBBBBBBBBBBBBBBBBBBBBBBBB`
|
||||
require.Equal(t, 4096, len(expected)) // Sanity check
|
||||
require.Equal(t, expected, m.Message)
|
||||
return nil
|
||||
require.Equal(t, expected, readAll(t, r.Body))
|
||||
})
|
||||
conf.SMTPServerAddrPrefix = ""
|
||||
session, _ := backend.AnonymousLogin(nil)
|
||||
session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
|
||||
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
|
||||
require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
|
||||
require.Nil(t, session.Data(strings.NewReader(email)))
|
||||
|
@ -288,21 +285,41 @@ Content-Type: text/SOMETHINGELSE
|
|||
|
||||
what's up
|
||||
`
|
||||
conf, backend := newTestBackend(t, func(m *message) error {
|
||||
return nil
|
||||
conf, backend := newTestBackend(t, func(http.ResponseWriter, *http.Request) {
|
||||
// Nothing.
|
||||
})
|
||||
conf.SMTPServerAddrPrefix = ""
|
||||
session, _ := backend.Login(nil, "user", "pass")
|
||||
session, _ := backend.Login(fakeConnState(t, "1.2.3.4"), "user", "pass")
|
||||
require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
|
||||
require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
|
||||
require.Equal(t, errUnsupportedContentType, session.Data(strings.NewReader(email)))
|
||||
}
|
||||
|
||||
func newTestBackend(t *testing.T, sub subscriber) (*Config, *smtpBackend) {
|
||||
func newTestBackend(t *testing.T, handler func(http.ResponseWriter, *http.Request)) (*Config, *smtpBackend) {
|
||||
conf := newTestConfig(t)
|
||||
conf.SMTPServerListen = ":25"
|
||||
conf.SMTPServerDomain = "ntfy.sh"
|
||||
conf.SMTPServerAddrPrefix = "ntfy-"
|
||||
backend := newMailBackend(conf, sub)
|
||||
backend := newMailBackend(conf, handler)
|
||||
return conf, backend
|
||||
}
|
||||
|
||||
func readAll(t *testing.T, rc io.ReadCloser) string {
|
||||
b, err := io.ReadAll(rc)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func fakeConnState(t *testing.T, remoteAddr string) *smtp.ConnectionState {
|
||||
ip, err := net.ResolveIPAddr("ip", remoteAddr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return &smtp.ConnectionState{
|
||||
Hostname: "myhostname",
|
||||
LocalAddr: ip,
|
||||
RemoteAddr: ip,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,7 +15,7 @@ type topic struct {
|
|||
}
|
||||
|
||||
// subscriber is a function that is called for every new message on a topic
|
||||
type subscriber func(msg *message) error
|
||||
type subscriber func(v *visitor, msg *message) error
|
||||
|
||||
// newTopic creates a new topic
|
||||
func newTopic(id string) *topic {
|
||||
|
@ -42,12 +42,12 @@ func (t *topic) Unsubscribe(id int) {
|
|||
}
|
||||
|
||||
// Publish asynchronously publishes to all subscribers
|
||||
func (t *topic) Publish(m *message) error {
|
||||
func (t *topic) Publish(v *visitor, m *message) error {
|
||||
go func() {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
for _, s := range t.subscribers {
|
||||
if err := s(m); err != nil {
|
||||
if err := s(v, m); err != nil {
|
||||
log.Printf("error publishing message to subscriber")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -28,6 +28,7 @@ type visitor struct {
|
|||
emails *rate.Limiter
|
||||
subscriptions util.Limiter
|
||||
bandwidth util.Limiter
|
||||
firebase time.Time // Next allowed Firebase message
|
||||
seen time.Time
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
@ -48,14 +49,11 @@ func newVisitor(conf *Config, messageCache *messageCache, ip string) *visitor {
|
|||
emails: rate.NewLimiter(rate.Every(conf.VisitorEmailLimitReplenish), conf.VisitorEmailLimitBurst),
|
||||
subscriptions: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)),
|
||||
bandwidth: util.NewBytesLimiter(conf.VisitorAttachmentDailyBandwidthLimit, 24*time.Hour),
|
||||
firebase: time.Unix(0, 0),
|
||||
seen: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (v *visitor) IP() string {
|
||||
return v.ip
|
||||
}
|
||||
|
||||
func (v *visitor) RequestAllowed() error {
|
||||
if !v.requests.Allow() {
|
||||
return errVisitorLimitReached
|
||||
|
@ -63,6 +61,21 @@ func (v *visitor) RequestAllowed() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (v *visitor) FirebaseAllowed() error {
|
||||
v.mu.Lock()
|
||||
defer v.mu.Unlock()
|
||||
if time.Now().Before(v.firebase) {
|
||||
return errVisitorLimitReached
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *visitor) FirebaseTemporarilyDeny() {
|
||||
v.mu.Lock()
|
||||
defer v.mu.Unlock()
|
||||
v.firebase = time.Now().Add(v.config.FirebaseQuotaLimitPenaltyDuration)
|
||||
}
|
||||
|
||||
func (v *visitor) EmailAllowed() error {
|
||||
if !v.emails.Allow() {
|
||||
return errVisitorLimitReached
|
||||
|
|
Loading…
Reference in a new issue