diff --git a/cmd/serve.go b/cmd/serve.go index aa69c19f..cfaecda3 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -41,6 +41,7 @@ var flagsServe = []cli.Flag{ altsrc.NewDurationFlag(&cli.DurationFlag{Name: "keepalive-interval", Aliases: []string{"keepalive_interval", "k"}, EnvVars: []string{"NTFY_KEEPALIVE_INTERVAL"}, Value: server.DefaultKeepaliveInterval, Usage: "interval of keepalive messages"}), altsrc.NewDurationFlag(&cli.DurationFlag{Name: "manager-interval", Aliases: []string{"manager_interval", "m"}, EnvVars: []string{"NTFY_MANAGER_INTERVAL"}, Value: server.DefaultManagerInterval, Usage: "interval of for message pruning and stats printing"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "web-root", Aliases: []string{"web_root"}, EnvVars: []string{"NTFY_WEB_ROOT"}, Value: "app", Usage: "sets web root to landing page (home), web app (app) or disabled (disable)"}), + altsrc.NewStringFlag(&cli.StringFlag{Name: "forward-poll-url", Aliases: []string{"forward_poll_url"}, EnvVars: []string{"NTFY_FORWARD_POLL_URL"}, Value: "", Usage: ""}), altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-sender-addr", Aliases: []string{"smtp_sender_addr"}, EnvVars: []string{"NTFY_SMTP_SENDER_ADDR"}, Usage: "SMTP server address (host:port) for outgoing emails"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-sender-user", Aliases: []string{"smtp_sender_user"}, EnvVars: []string{"NTFY_SMTP_SENDER_USER"}, Usage: "SMTP user (if e-mail sending is enabled)"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-sender-pass", Aliases: []string{"smtp_sender_pass"}, EnvVars: []string{"NTFY_SMTP_SENDER_PASS"}, Usage: "SMTP password (if e-mail sending is enabled)"}), @@ -102,6 +103,7 @@ func execServe(c *cli.Context) error { keepaliveInterval := c.Duration("keepalive-interval") managerInterval := c.Duration("manager-interval") webRoot := c.String("web-root") + forwardPollURL := c.String("forward-poll-url") smtpSenderAddr := c.String("smtp-sender-addr") smtpSenderUser := c.String("smtp-sender-user") smtpSenderPass := c.String("smtp-sender-pass") @@ -147,6 +149,8 @@ func execServe(c *cli.Context) error { return errors.New("if set, auth-default-access must start set to 'read-write', 'read-only', 'write-only' or 'deny-all'") } else if !util.InStringList([]string{"app", "home", "disable"}, webRoot) { return errors.New("if set, web-root must be 'home' or 'app'") + } else if forwardPollURL != "" && !strings.HasPrefix(forwardPollURL, "http://") && !strings.HasPrefix(forwardPollURL, "https://") { + return errors.New("if set, forward-poll-url must start with http:// or https://") } webRootIsApp := webRoot == "app" @@ -215,6 +219,7 @@ func execServe(c *cli.Context) error { conf.KeepaliveInterval = keepaliveInterval conf.ManagerInterval = managerInterval conf.WebRootIsApp = webRootIsApp + conf.ForwardPollURL = forwardPollURL conf.SMTPSenderAddr = smtpSenderAddr conf.SMTPSenderUser = smtpSenderUser conf.SMTPSenderPass = smtpSenderPass diff --git a/server/config.go b/server/config.go index d36d5c66..2bb4b895 100644 --- a/server/config.go +++ b/server/config.go @@ -69,6 +69,7 @@ type Config struct { AtSenderInterval time.Duration FirebaseKeepaliveInterval time.Duration FirebasePollInterval time.Duration + ForwardPollURL string SMTPSenderAddr string SMTPSenderUser string SMTPSenderPass string diff --git a/server/server.go b/server/server.go index 55562d37..e3b738d9 100644 --- a/server/server.go +++ b/server/server.go @@ -3,6 +3,7 @@ package server import ( "bytes" "context" + "crypto/sha256" "embed" "encoding/base64" "encoding/json" @@ -93,6 +94,7 @@ const ( firebaseControlTopic = "~control" // See Android if changed firebasePollTopic = "~poll" // See iOS if changed emptyMessageBody = "triggered" // Used if message body is empty + newMessageBody = "New message" // Used in poll requests as generic message defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment encodingBase64 = "base64" ) @@ -422,6 +424,9 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito if err != nil { return err } + if m.PollID != "" { + m = newPollRequestMessage(t.ID, m.PollID) + } if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil { return err } @@ -448,6 +453,28 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito } }() } + if s.config.ForwardPollURL != "" { + go func() { + topicURL := fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic) + topicHash := fmt.Sprintf("%x", sha256.Sum256([]byte(topicURL))) + forwardURL := fmt.Sprintf("%s/%s", s.config.ForwardPollURL, topicHash) + log.Printf("forwarding: topicURL %s, to upstream url %s", topicURL, forwardURL) + req, err := http.NewRequest("POST", forwardURL, strings.NewReader("")) + if err != nil { + log.Printf("[%s] FWD - Unable to forward poll request: %v", v.ip, err.Error()) + return + } + req.Header.Set("X-Poll-ID", m.ID) + response, err := http.DefaultClient.Do(req) + if err != nil { + log.Printf("[%s] FWD - Unable to forward poll request: %v", v.ip, err.Error()) + return + } else if response.StatusCode != http.StatusOK { + log.Printf("[%s] FWD - Unable to forward poll request, unexpected status: %d", v.ip, response.StatusCode) + return + } + }() + } if cache { if err := s.messageCache.AddMessage(m); err != nil { return err @@ -549,6 +576,12 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca firebase = false unifiedpush = true } + m.PollID = readParam(r, "x-poll-id", "poll-id", "poll") + if m.PollID != "" { + unifiedpush = false + cache = false + email = "" + } return cache, firebase, email, unifiedpush, nil } @@ -565,7 +598,9 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca // 5. curl -T file.txt ntfy.sh/mytopic // If file.txt is > message limit, treat it as an attachment func (s *Server) handlePublishBody(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser, unifiedpush bool) error { - if unifiedpush { + if m.Event == pollRequestEvent { + return nil // Ignore body + } else if unifiedpush { return s.handleBodyAsMessageAutoDetect(m, body) // Case 1 } else if m.Attachment != nil && m.Attachment.URL != "" { return s.handleBodyAsTextMessage(m, body) // Case 2 @@ -710,6 +745,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v * w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset! if poll { + log.Printf("polling %#v", r.URL) return s.sendOldMessages(topics, since, scheduled, sub) } subscriberIDs := make([]int, 0) diff --git a/server/server_firebase.go b/server/server_firebase.go index ad0da0e2..373cb458 100644 --- a/server/server_firebase.go +++ b/server/server_firebase.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "log" "strings" firebase "firebase.google.com/go" @@ -64,6 +65,7 @@ func createFirebaseSubscriber(credentialsFile string, auther auth.Auther) (subsc if err != nil { return err } + log.Printf("Sending %#v %#v", m, fbm) _, err = msg.Send(context.Background(), fbm) return err }, nil @@ -98,6 +100,31 @@ func toFirebaseMessage(m *message, auther auth.Auther) (*messaging.Message, erro CustomData: apnsData, }, } + case pollRequestEvent: + data = map[string]string{ + "id": m.ID, + "time": fmt.Sprintf("%d", m.Time), + "event": m.Event, + "topic": m.Topic, + "message": m.Message, + "poll_id": m.PollID, + } + apnsData := make(map[string]interface{}) + for k, v := range data { + apnsData[k] = v + } + apnsConfig = &messaging.APNSConfig{ + Payload: &messaging.APNSPayload{ + CustomData: apnsData, + Aps: &messaging.Aps{ + MutableContent: true, + Alert: &messaging.ApsAlert{ + Title: m.Title, + Body: maybeTruncateAPNSBodyMessage(m.Message), + }, + }, + }, + } case messageEvent: allowForward := true if auther != nil { diff --git a/server/types.go b/server/types.go index 3f6fcdbd..6a69338c 100644 --- a/server/types.go +++ b/server/types.go @@ -24,13 +24,14 @@ type message struct { Time int64 `json:"time"` // Unix time in seconds Event string `json:"event"` // One of the above Topic string `json:"topic"` + Title string `json:"title,omitempty"` + Message string `json:"message,omitempty"` Priority int `json:"priority,omitempty"` Tags []string `json:"tags,omitempty"` Click string `json:"click,omitempty"` Actions []*action `json:"actions,omitempty"` Attachment *attachment `json:"attachment,omitempty"` - Title string `json:"title,omitempty"` - Message string `json:"message,omitempty"` + PollID string `json:"poll_id,omitempty"` Encoding string `json:"encoding,omitempty"` // empty for raw UTF-8, or "base64" for encoded bytes } @@ -84,14 +85,11 @@ type messageEncoder func(msg *message) (string, error) // newMessage creates a new message with the current timestamp func newMessage(event, topic, msg string) *message { return &message{ - ID: util.RandomString(messageIDLength), - Time: time.Now().Unix(), - Event: event, - Topic: topic, - Priority: 0, - Tags: nil, - Title: "", - Message: msg, + ID: util.RandomString(messageIDLength), + Time: time.Now().Unix(), + Event: event, + Topic: topic, + Message: msg, } } @@ -110,6 +108,13 @@ func newDefaultMessage(topic, msg string) *message { return newMessage(messageEvent, topic, msg) } +// newPollRequestMessage is a convenience method to create a poll request message +func newPollRequestMessage(topic, pollID string) *message { + m := newMessage(pollRequestEvent, topic, newMessageBody) + m.PollID = pollID + return m +} + func validMessageID(s string) bool { return util.ValidRandomString(s, messageIDLength) }