diff --git a/client/client.go b/client/client.go index 0a1022c2..455a9aa6 100644 --- a/client/client.go +++ b/client/client.go @@ -6,6 +6,7 @@ import ( "context" "encoding/json" "fmt" + "heckel.io/ntfy/util" "io" "log" "net/http" @@ -39,16 +40,21 @@ type Message struct { Event string Time int64 Topic string - TopicURL string Message string Title string Priority int Tags []string - Raw string + + // Additional fields + TopicURL string + SubscriptionID string + Raw string } type subscription struct { - cancel context.CancelFunc + ID string + topicURL string + cancel context.CancelFunc } // New creates a new Client using a given Config @@ -88,7 +94,7 @@ func (c *Client) Publish(topic, message string, options ...PublishOption) (*Mess if err != nil { return nil, err } - m, err := toMessage(string(b), topicURL) + m, err := toMessage(string(b), topicURL, "") if err != nil { return nil, err } @@ -111,7 +117,7 @@ func (c *Client) Poll(topic string, options ...SubscribeOption) ([]*Message, err errChan := make(chan error) topicURL := c.expandTopicURL(topic) go func() { - err := performSubscribeRequest(ctx, msgChan, topicURL, options...) + err := performSubscribeRequest(ctx, msgChan, topicURL, "", options...) close(msgChan) errChan <- err }() @@ -131,39 +137,58 @@ func (c *Client) Poll(topic string, options ...SubscribeOption) ([]*Message, err // By default, only new messages will be returned, but you can change this behavior using a SubscribeOption. // See WithSince, WithSinceAll, WithSinceUnixTime, WithScheduled, and the generic WithQueryParam. // +// The method returns a unique subscriptionID that can be used in Unsubscribe. +// // Example: // c := client.New(client.NewConfig()) -// c.Subscribe("mytopic") +// subscriptionID := c.Subscribe("mytopic") // for m := range c.Messages { // fmt.Printf("New message: %s", m.Message) // } func (c *Client) Subscribe(topic string, options ...SubscribeOption) string { c.mu.Lock() defer c.mu.Unlock() + subscriptionID := util.RandomString(10) topicURL := c.expandTopicURL(topic) - if _, ok := c.subscriptions[topicURL]; ok { - return topicURL - } ctx, cancel := context.WithCancel(context.Background()) - c.subscriptions[topicURL] = &subscription{cancel} - go handleSubscribeConnLoop(ctx, c.Messages, topicURL, options...) - return topicURL + c.subscriptions[subscriptionID] = &subscription{ + ID: subscriptionID, + topicURL: topicURL, + cancel: cancel, + } + go handleSubscribeConnLoop(ctx, c.Messages, topicURL, subscriptionID, options...) + return subscriptionID } -// Unsubscribe unsubscribes from a topic that has been previously subscribed with Subscribe. +// Unsubscribe unsubscribes from a topic that has been previously subscribed to using the unique +// subscriptionID returned in Subscribe. +func (c *Client) Unsubscribe(subscriptionID string) { + c.mu.Lock() + defer c.mu.Unlock() + sub, ok := c.subscriptions[subscriptionID] + if !ok { + return + } + delete(c.subscriptions, subscriptionID) + sub.cancel() +} + +// UnsubscribeAll unsubscribes from a topic that has been previously subscribed with Subscribe. +// If there are multiple subscriptions matching the topic, all of them are unsubscribed from. // // A topic can be either a full URL (e.g. https://myhost.lan/mytopic), a short URL which is then prepended https:// // (e.g. myhost.lan -> https://myhost.lan), or a short name which is expanded using the default host in the // config (e.g. mytopic -> https://ntfy.sh/mytopic). -func (c *Client) Unsubscribe(topic string) { +func (c *Client) UnsubscribeAll(topic string) { c.mu.Lock() defer c.mu.Unlock() topicURL := c.expandTopicURL(topic) - sub, ok := c.subscriptions[topicURL] - if !ok { - return + for _, sub := range c.subscriptions { + if sub.topicURL == topicURL { + delete(c.subscriptions, sub.ID) + sub.cancel() + } } - sub.cancel() } func (c *Client) expandTopicURL(topic string) string { @@ -175,9 +200,11 @@ func (c *Client) expandTopicURL(topic string) string { return fmt.Sprintf("%s/%s", c.config.DefaultHost, topic) } -func handleSubscribeConnLoop(ctx context.Context, msgChan chan *Message, topicURL string, options ...SubscribeOption) { +func handleSubscribeConnLoop(ctx context.Context, msgChan chan *Message, topicURL, subcriptionID string, options ...SubscribeOption) { for { - if err := performSubscribeRequest(ctx, msgChan, topicURL, options...); err != nil { + // TODO The retry logic is crude and may lose messages. It should record the last message like the + // Android client, use since=, and do incremental backoff too + if err := performSubscribeRequest(ctx, msgChan, topicURL, subcriptionID, options...); err != nil { log.Printf("Connection to %s failed: %s", topicURL, err.Error()) } select { @@ -189,7 +216,7 @@ func handleSubscribeConnLoop(ctx context.Context, msgChan chan *Message, topicUR } } -func performSubscribeRequest(ctx context.Context, msgChan chan *Message, topicURL string, options ...SubscribeOption) error { +func performSubscribeRequest(ctx context.Context, msgChan chan *Message, topicURL string, subscriptionID string, options ...SubscribeOption) error { req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s/json", topicURL), nil) if err != nil { return err @@ -206,7 +233,7 @@ func performSubscribeRequest(ctx context.Context, msgChan chan *Message, topicUR defer resp.Body.Close() scanner := bufio.NewScanner(resp.Body) for scanner.Scan() { - m, err := toMessage(scanner.Text(), topicURL) + m, err := toMessage(scanner.Text(), topicURL, subscriptionID) if err != nil { return err } @@ -215,12 +242,13 @@ func performSubscribeRequest(ctx context.Context, msgChan chan *Message, topicUR return nil } -func toMessage(s, topicURL string) (*Message, error) { +func toMessage(s, topicURL, subscriptionID string) (*Message, error) { var m *Message if err := json.NewDecoder(strings.NewReader(s)).Decode(&m); err != nil { return nil, err } m.TopicURL = topicURL + m.SubscriptionID = subscriptionID m.Raw = s return m, nil } diff --git a/client/config.go b/client/config.go index 8c196118..c2a40d88 100644 --- a/client/config.go +++ b/client/config.go @@ -9,9 +9,9 @@ const ( type Config struct { DefaultHost string `yaml:"default-host"` Subscribe []struct { - Topic string `yaml:"topic"` - Command string `yaml:"command"` - // If []map[string]string TODO This would be cool + Topic string `yaml:"topic"` + Command string `yaml:"command"` + If map[string]string `yaml:"if"` } `yaml:"subscribe"` } diff --git a/client/options.go b/client/options.go index ee8cce5a..dd180f79 100644 --- a/client/options.go +++ b/client/options.go @@ -88,6 +88,32 @@ func WithScheduled() SubscribeOption { return WithQueryParam("scheduled", "1") } +// WithFilter is a generic subscribe option meant to be used to filter for certain messages only +func WithFilter(param, value string) SubscribeOption { + return WithQueryParam(param, value) +} + +// WithMessageFilter instructs the server to only return messages that match the exact message +func WithMessageFilter(message string) SubscribeOption { + return WithQueryParam("message", message) +} + +// WithTitleFilter instructs the server to only return messages with a title that match the exact string +func WithTitleFilter(title string) SubscribeOption { + return WithQueryParam("title", title) +} + +// WithPriorityFilter instructs the server to only return messages with the matching priority. Not that messages +// without priority also implicitly match priority 3. +func WithPriorityFilter(priority int) SubscribeOption { + return WithQueryParam("priority", fmt.Sprintf("%d", priority)) +} + +// WithTagsFilter instructs the server to only return messages that contain all of the given tags +func WithTagsFilter(tags []string) SubscribeOption { + return WithQueryParam("tags", strings.Join(tags, ",")) +} + // WithHeader is a generic option to add headers to a request func WithHeader(header, value string) RequestOption { return func(r *http.Request) error { diff --git a/cmd/subscribe.go b/cmd/subscribe.go index a45b9c2e..0ba247fc 100644 --- a/cmd/subscribe.go +++ b/cmd/subscribe.go @@ -135,17 +135,21 @@ func doPollSingle(c *cli.Context, cl *client.Client, topic, command string, opti } func doSubscribe(c *cli.Context, cl *client.Client, conf *client.Config, topic, command string, options ...client.SubscribeOption) error { - commands := make(map[string]string) - for _, s := range conf.Subscribe { // May be nil - topicURL := cl.Subscribe(s.Topic, options...) - commands[topicURL] = s.Command + commands := make(map[string]string) // Subscription ID -> command + for _, s := range conf.Subscribe { // May be nil + topicOptions := append(make([]client.SubscribeOption, 0), options...) + for filter, value := range s.If { + topicOptions = append(topicOptions, client.WithFilter(filter, value)) + } + subscriptionID := cl.Subscribe(s.Topic, topicOptions...) + commands[subscriptionID] = s.Command } if topic != "" { - topicURL := cl.Subscribe(topic, options...) - commands[topicURL] = command + subscriptionID := cl.Subscribe(topic, options...) + commands[subscriptionID] = command } for m := range cl.Messages { - command, ok := commands[m.TopicURL] + command, ok := commands[m.SubscriptionID] if !ok { continue } diff --git a/server/server.go b/server/server.go index d4b9f939..38198bdb 100644 --- a/server/server.go +++ b/server/server.go @@ -334,7 +334,7 @@ func (s *Server) parseParams(r *http.Request, m *message) (cache bool, firebase tagsStr := readParam(r, "x-tags", "tag", "tags", "ta") if tagsStr != "" { m.Tags = make([]string, 0) - for _, s := range strings.Split(tagsStr, ",") { + for _, s := range util.SplitNoEmpty(tagsStr, ",") { m.Tags = append(m.Tags, strings.TrimSpace(s)) } } @@ -413,7 +413,7 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visi } defer v.RemoveSubscription() topicsStr := strings.TrimSuffix(r.URL.Path[1:], "/"+format) // Hack - topicIDs := strings.Split(topicsStr, ",") + topicIDs := util.SplitNoEmpty(topicsStr, ",") topics, err := s.topicsFromIDs(topicIDs...) if err != nil { return err @@ -425,13 +425,20 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visi var wlock sync.Mutex poll := r.URL.Query().Has("poll") scheduled := r.URL.Query().Has("scheduled") || r.URL.Query().Has("sched") + messageFilter, titleFilter, priorityFilter, tagsFilter, err := parseQueryFilters(r) + if err != nil { + return err + } sub := func(msg *message) error { - wlock.Lock() - defer wlock.Unlock() + if !passesQueryFilter(msg, messageFilter, titleFilter, priorityFilter, tagsFilter) { + return nil + } m, err := encoder(msg) if err != nil { return err } + wlock.Lock() + defer wlock.Unlock() if _, err := w.Write([]byte(m)); err != nil { return err } @@ -473,6 +480,34 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visi } } +func parseQueryFilters(r *http.Request) (messageFilter string, titleFilter string, priorityFilter int, tagsFilter []string, err error) { + messageFilter = r.URL.Query().Get("message") + titleFilter = r.URL.Query().Get("title") + tagsFilter = util.SplitNoEmpty(r.URL.Query().Get("tags"), ",") + priorityFilter, err = util.ParsePriority(r.URL.Query().Get("priority")) + return +} + +func passesQueryFilter(msg *message, messageFilter string, titleFilter string, priorityFilter int, tagsFilter []string) bool { + if messageFilter != "" && msg.Message != messageFilter { + log.Printf("1") + return false + } + if titleFilter != "" && msg.Title != titleFilter { + log.Printf("2") + return false + } + if priorityFilter > 0 && (msg.Priority != priorityFilter || (msg.Priority == 0 && priorityFilter != 3)) { + log.Printf("3") + return false + } + if len(tagsFilter) > 0 && !util.InStringListAll(msg.Tags, tagsFilter) { + log.Printf("4") + return false + } + return true +} + func (s *Server) sendOldMessages(topics []*topic, since sinceTime, scheduled bool, sub subscriber) error { if since.IsNone() { return nil diff --git a/util/util.go b/util/util.go index 010c0f58..8243bccb 100644 --- a/util/util.go +++ b/util/util.go @@ -37,6 +37,30 @@ func InStringList(haystack []string, needle string) bool { return false } +// InStringListAll returns true if all needles are contained in haystack +func InStringListAll(haystack []string, needles []string) bool { + matches := 0 + for _, s := range haystack { + for _, needle := range needles { + if s == needle { + matches++ + } + } + } + return matches == len(needles) +} + +// SplitNoEmpty splits a string using strings.Split, but filters out empty strings +func SplitNoEmpty(s string, sep string) []string { + res := make([]string, 0) + for _, r := range strings.Split(s, sep) { + if r != "" { + res = append(res, r) + } + } + return res +} + // RandomString returns a random string with a given length func RandomString(length int) string { randomMutex.Lock() // Who would have thought that random.Intn() is not thread-safe?! diff --git a/util/util_test.go b/util/util_test.go index 06c2af9b..50a3a689 100644 --- a/util/util_test.go +++ b/util/util_test.go @@ -56,6 +56,19 @@ func TestInStringList(t *testing.T) { require.False(t, InStringList(s, "three")) } +func TestInStringListAll(t *testing.T) { + s := []string{"one", "two", "three", "four"} + require.True(t, InStringListAll(s, []string{"two", "four"})) + require.False(t, InStringListAll(s, []string{"three", "five"})) +} + +func TestSplitNoEmpty(t *testing.T) { + require.Equal(t, []string{}, SplitNoEmpty("", ",")) + require.Equal(t, []string{}, SplitNoEmpty(",,,", ",")) + require.Equal(t, []string{"tag1", "tag2"}, SplitNoEmpty("tag1,tag2", ",")) + require.Equal(t, []string{"tag1", "tag2"}, SplitNoEmpty("tag1,tag2,", ",")) +} + func TestExpandHome_WithTilde(t *testing.T) { require.Equal(t, os.Getenv("HOME")+"/this/is/a/path", ExpandHome("~/this/is/a/path")) }