diff --git a/cmd/serve.go b/cmd/serve.go index 550e26a0..83a5b838 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -26,6 +26,9 @@ var flagsServe = []cli.Flag{ altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-user", EnvVars: []string{"NTFY_SMTP_USER"}, Usage: "SMTP user (if e-mail sending is enabled)"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-pass", EnvVars: []string{"NTFY_SMTP_PASS"}, Usage: "SMTP password (if e-mail sending is enabled)"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-from", EnvVars: []string{"NTFY_SMTP_FROM"}, Usage: "SMTP sender address (if e-mail sending is enabled)"}), + altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-server-listen", EnvVars: []string{"NTFY_SMTP_SERVER_LISTEN"}, Usage: "xxxxxxxxxx"}), + altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-server-domain", EnvVars: []string{"NTFY_SMTP_SERVER_DOMAIN"}, Usage: "xxxxxxxxxxx"}), + altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-server-addr-prefix", EnvVars: []string{"NTFY_SMTP_SERVER_ADDR_PREFIX"}, Usage: "xxxxxxxxxxx"}), altsrc.NewIntFlag(&cli.IntFlag{Name: "global-topic-limit", Aliases: []string{"T"}, EnvVars: []string{"NTFY_GLOBAL_TOPIC_LIMIT"}, Value: server.DefaultGlobalTopicLimit, Usage: "total number of topics allowed"}), altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-subscription-limit", EnvVars: []string{"NTFY_VISITOR_SUBSCRIPTION_LIMIT"}, Value: server.DefaultVisitorSubscriptionLimit, Usage: "number of subscriptions per visitor"}), altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-request-limit-burst", EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_BURST"}, Value: server.DefaultVisitorRequestLimitBurst, Usage: "initial limit of requests per visitor"}), @@ -68,10 +71,13 @@ func execServe(c *cli.Context) error { cacheDuration := c.Duration("cache-duration") keepaliveInterval := c.Duration("keepalive-interval") managerInterval := c.Duration("manager-interval") - smtpAddr := c.String("smtp-addr") - smtpUser := c.String("smtp-user") - smtpPass := c.String("smtp-pass") - smtpFrom := c.String("smtp-from") + smtpSenderAddr := c.String("smtp-addr") + smtpSenderUser := c.String("smtp-user") + smtpSenderPass := c.String("smtp-pass") + smtpSenderFrom := c.String("smtp-from") + smtpServerListen := c.String("smtp-server-listen") + smtpServerDomain := c.String("smtp-server-domain") + smtpServerAddrPrefix := c.String("smtp-server-addr-prefix") globalTopicLimit := c.Int("global-topic-limit") visitorSubscriptionLimit := c.Int("visitor-subscription-limit") visitorRequestLimitBurst := c.Int("visitor-request-limit-burst") @@ -95,7 +101,7 @@ func execServe(c *cli.Context) error { return errors.New("if set, certificate file must exist") } else if listenHTTPS != "" && (keyFile == "" || certFile == "") { return errors.New("if listen-https is set, both key-file and cert-file must be set") - } else if smtpAddr != "" && (baseURL == "" || smtpUser == "" || smtpPass == "" || smtpFrom == "") { + } else if smtpSenderAddr != "" && (baseURL == "" || smtpSenderUser == "" || smtpSenderPass == "" || smtpSenderFrom == "") { return errors.New("if smtp-addr is set, base-url, smtp-user, smtp-pass and smtp-from must also be set") } @@ -111,10 +117,13 @@ func execServe(c *cli.Context) error { conf.CacheDuration = cacheDuration conf.KeepaliveInterval = keepaliveInterval conf.ManagerInterval = managerInterval - conf.SMTPAddr = smtpAddr - conf.SMTPUser = smtpUser - conf.SMTPPass = smtpPass - conf.SMTPFrom = smtpFrom + conf.SMTPSenderAddr = smtpSenderAddr + conf.SMTPSenderUser = smtpSenderUser + conf.SMTPSenderPass = smtpSenderPass + conf.SMTPSenderFrom = smtpSenderFrom + conf.SMTPServerListen = smtpServerListen + conf.SMTPServerDomain = smtpServerDomain + conf.SMTPServerAddrPrefix = smtpServerAddrPrefix conf.GlobalTopicLimit = globalTopicLimit conf.VisitorSubscriptionLimit = visitorSubscriptionLimit conf.VisitorRequestLimitBurst = visitorRequestLimitBurst diff --git a/server/config.go b/server/config.go index b127d64d..b5422baa 100644 --- a/server/config.go +++ b/server/config.go @@ -45,10 +45,13 @@ type Config struct { ManagerInterval time.Duration AtSenderInterval time.Duration FirebaseKeepaliveInterval time.Duration - SMTPAddr string - SMTPUser string - SMTPPass string - SMTPFrom string + SMTPSenderAddr string + SMTPSenderUser string + SMTPSenderPass string + SMTPSenderFrom string + SMTPServerListen string + SMTPServerDomain string + SMTPServerAddrPrefix string MessageLimit int MinDelay time.Duration MaxDelay time.Duration diff --git a/server/mailserver.go b/server/mailserver.go deleted file mode 100644 index 08f2c193..00000000 --- a/server/mailserver.go +++ /dev/null @@ -1,102 +0,0 @@ -package server - -import ( - "bytes" - "errors" - "fmt" - "github.com/emersion/go-smtp" - "io" - "io/ioutil" - "log" - "net/http" - "net/http/httptest" - "net/mail" - "strings" - "sync" -) - -// mailBackend implements SMTP server methods. -type mailBackend struct { - s *Server -} - -func (b *mailBackend) Login(state *smtp.ConnectionState, username, password string) (smtp.Session, error) { - return &Session{s: b.s}, nil -} - -func (b *mailBackend) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) { - return &Session{s: b.s}, nil -} - -// Session is returned after EHLO. -type Session struct { - s *Server - from, to string - mu sync.Mutex -} - -func (s *Session) AuthPlain(username, password string) error { - return nil -} - -func (s *Session) Mail(from string, opts smtp.MailOptions) error { - s.mu.Lock() - defer s.mu.Unlock() - s.from = from - log.Println("Mail from:", from) - return nil -} - -func (s *Session) Rcpt(to string) error { - s.mu.Lock() - defer s.mu.Unlock() - s.to = to - log.Println("Rcpt to:", to) - return nil -} - -func (s *Session) Data(r io.Reader) error { - s.mu.Lock() - defer s.mu.Unlock() - b, err := ioutil.ReadAll(r) - if err != nil { - return err - } - - log.Println("Data:", string(b)) - msg, err := mail.ReadMessage(bytes.NewReader(b)) - if err != nil { - return err - } - body, err := io.ReadAll(msg.Body) - if err != nil { - return err - } - topic := strings.TrimSuffix(s.to, "@ntfy.sh") - url := fmt.Sprintf("%s/%s", s.s.config.BaseURL, topic) - req, err := http.NewRequest("PUT", url, bytes.NewReader(body)) - if err != nil { - return err - } - subject := msg.Header.Get("Subject") - if subject != "" { - req.Header.Set("Title", subject) - } - rr := httptest.NewRecorder() - s.s.handle(rr, req) - if rr.Code != http.StatusOK { - return errors.New("error: " + rr.Body.String()) - } - return nil -} - -func (s *Session) Reset() { - s.mu.Lock() - s.from = "" - s.to = "" - s.mu.Unlock() -} - -func (s *Session) Logout() error { - return nil -} diff --git a/server/server.go b/server/server.go index c2f8034b..6461085c 100644 --- a/server/server.go +++ b/server/server.go @@ -5,6 +5,7 @@ import ( "context" "embed" "encoding/json" + "errors" firebase "firebase.google.com/go" "firebase.google.com/go/messaging" "fmt" @@ -16,6 +17,7 @@ import ( "log" "net" "net/http" + "net/http/httptest" "regexp" "strconv" "strings" @@ -147,8 +149,8 @@ func New(conf *Config) (*Server, error) { } } var mailer mailer - if conf.SMTPAddr != "" { - mailer = &smtpMailer{config: conf} + if conf.SMTPSenderAddr != "" { + mailer = &smtpSender{config: conf} } cache, err := createCache(conf) if err != nil { @@ -239,9 +241,9 @@ func (s *Server) Run() error { errChan <- s.httpsServer.ListenAndServeTLS(s.config.CertFile, s.config.KeyFile) }() } - if true { + if s.config.SMTPServerListen != "" { go func() { - errChan <- s.mailserver() + errChan <- s.runMailserver() }() } s.mu.Unlock() @@ -729,15 +731,31 @@ func (s *Server) updateStatsAndPrune() { s.messages, len(s.topics), subscribers, messages, len(s.visitors)) } -func (s *Server) mailserver() error { - ms := smtp.NewServer(&mailBackend{s}) +func (s *Server) runMailserver() error { + sub := func(m *message) error { + url := fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic) + req, err := http.NewRequest("PUT", url, strings.NewReader(m.Message)) + if err != nil { + return err + } + if m.Title != "" { + req.Header.Set("Title", m.Title) + } + rr := httptest.NewRecorder() + s.handle(rr, req) + if rr.Code != http.StatusOK { + return errors.New("error: " + rr.Body.String()) + } + return nil + } + ms := smtp.NewServer(newMailBackend(s.config, sub)) - ms.Addr = ":1025" - ms.Domain = "localhost" + ms.Addr = s.config.SMTPServerListen + ms.Domain = s.config.SMTPServerDomain ms.ReadTimeout = 10 * time.Second ms.WriteTimeout = 10 * time.Second - ms.MaxMessageBytes = 1024 * 1024 - ms.MaxRecipients = 50 + ms.MaxMessageBytes = 2 * s.config.MessageLimit + ms.MaxRecipients = 1 ms.AllowInsecureAuth = true log.Println("Starting server at", ms.Addr) diff --git a/server/server.yml b/server/server.yml index 8f8930f3..e6afefc4 100644 --- a/server/server.yml +++ b/server/server.yml @@ -59,6 +59,9 @@ # smtp-pass: # smtp-from: +# smtp-server-listen: +# smtp-server-addr: + # Interval in which keepalive messages are sent to the client. This is to prevent # intermediaries closing the connection for inactivity. # diff --git a/server/mailer.go b/server/smtp_sender.go similarity index 87% rename from server/mailer.go rename to server/smtp_sender.go index 22767212..15f004c1 100644 --- a/server/mailer.go +++ b/server/smtp_sender.go @@ -16,21 +16,21 @@ type mailer interface { Send(from, to string, m *message) error } -type smtpMailer struct { +type smtpSender struct { config *Config } -func (s *smtpMailer) Send(senderIP, to string, m *message) error { - host, _, err := net.SplitHostPort(s.config.SMTPAddr) +func (s *smtpSender) Send(senderIP, to string, m *message) error { + host, _, err := net.SplitHostPort(s.config.SMTPSenderAddr) if err != nil { return err } - message, err := formatMail(s.config.BaseURL, senderIP, s.config.SMTPFrom, to, m) + message, err := formatMail(s.config.BaseURL, senderIP, s.config.SMTPSenderFrom, to, m) if err != nil { return err } - auth := smtp.PlainAuth("", s.config.SMTPUser, s.config.SMTPPass, host) - return smtp.SendMail(s.config.SMTPAddr, auth, s.config.SMTPFrom, []string{to}, []byte(message)) + auth := smtp.PlainAuth("", s.config.SMTPSenderUser, s.config.SMTPSenderPass, host) + return smtp.SendMail(s.config.SMTPSenderAddr, auth, s.config.SMTPSenderFrom, []string{to}, []byte(message)) } func formatMail(baseURL, senderIP, from, to string, m *message) (string, error) { diff --git a/server/mailer_test.go b/server/smtp_sender_test.go similarity index 100% rename from server/mailer_test.go rename to server/smtp_sender_test.go diff --git a/server/smtp_server.go b/server/smtp_server.go new file mode 100644 index 00000000..f304dea8 --- /dev/null +++ b/server/smtp_server.go @@ -0,0 +1,108 @@ +package server + +import ( + "bytes" + "errors" + "github.com/emersion/go-smtp" + "io" + "io/ioutil" + "log" + "net/mail" + "strings" + "sync" +) + +// smtpBackend implements SMTP server methods. +type smtpBackend struct { + config *Config + sub subscriber +} + +func newMailBackend(conf *Config, sub subscriber) *smtpBackend { + return &smtpBackend{ + config: conf, + sub: sub, + } +} + +func (b *smtpBackend) Login(state *smtp.ConnectionState, username, password string) (smtp.Session, error) { + return &smtpSession{config: b.config, sub: b.sub}, nil +} + +func (b *smtpBackend) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) { + return &smtpSession{config: b.config, sub: b.sub}, nil +} + +// smtpSession is returned after EHLO. +type smtpSession struct { + config *Config + sub subscriber + from, to string + mu sync.Mutex +} + +func (s *smtpSession) AuthPlain(username, password string) error { + return nil +} + +func (s *smtpSession) Mail(from string, opts smtp.MailOptions) error { + s.mu.Lock() + defer s.mu.Unlock() + s.from = from + return nil +} + +func (s *smtpSession) Rcpt(to string) error { + s.mu.Lock() + defer s.mu.Unlock() + addressList, err := mail.ParseAddressList(to) + if err != nil { + return err + } else if len(addressList) != 1 { + return errors.New("only one recipient supported") + } else if !strings.HasSuffix(addressList[0].Address, "@"+s.config.SMTPServerDomain) { + return errors.New("invalid domain") + } else if s.config.SMTPServerAddrPrefix != "" && !strings.HasPrefix(addressList[0].Address, s.config.SMTPServerAddrPrefix) { + return errors.New("invalid address") + } + // FIXME check topic format + s.to = addressList[0].Address + return nil +} + +func (s *smtpSession) Data(r io.Reader) error { + s.mu.Lock() + defer s.mu.Unlock() + b, err := ioutil.ReadAll(r) + if err != nil { + return err + } + + log.Println("Data:", string(b)) + msg, err := mail.ReadMessage(bytes.NewReader(b)) + if err != nil { + return err + } + body, err := io.ReadAll(msg.Body) + if err != nil { + return err + } + topic := strings.TrimSuffix(s.to, "@"+s.config.SMTPServerDomain) + m := newDefaultMessage(topic, string(body)) + subject := msg.Header.Get("Subject") + if subject != "" { + m.Title = subject + } + return s.sub(m) +} + +func (s *smtpSession) Reset() { + s.mu.Lock() + s.from = "" + s.to = "" + s.mu.Unlock() +} + +func (s *smtpSession) Logout() error { + return nil +}