From f58c1e4c84e80f7fa1a2bea2e1b2bfae4240c7d0 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Thu, 1 Jun 2023 16:01:39 -0400 Subject: [PATCH] Fix previous fix --- client/client.go | 62 ++++++++++++++++++++----------------------- client/client_test.go | 2 +- cmd/app.go | 1 - cmd/publish.go | 4 --- cmd/subscribe.go | 12 ++++++--- 5 files changed, 38 insertions(+), 43 deletions(-) diff --git a/client/client.go b/client/client.go index e719e9ef..93cf7da5 100644 --- a/client/client.go +++ b/client/client.go @@ -11,23 +11,25 @@ import ( "heckel.io/ntfy/util" "io" "net/http" + "regexp" "strings" "sync" "time" ) -// Event type constants const ( - MessageEvent = "message" - KeepaliveEvent = "keepalive" - OpenEvent = "open" - PollRequestEvent = "poll_request" + // MessageEvent identifies a message event + MessageEvent = "message" ) const ( maxResponseBytes = 4096 ) +var ( + topicRegex = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`) // Same as in server/server.go +) + // Client is the ntfy client that can be used to publish and subscribe to ntfy topics type Client struct { Messages chan *Message @@ -96,7 +98,10 @@ func (c *Client) Publish(topic, message string, options ...PublishOption) (*Mess // To pass title, priority and tags, check out WithTitle, WithPriority, WithTagsList, WithDelay, WithNoCache, // WithNoFirebase, and the generic WithHeader. func (c *Client) PublishReader(topic string, body io.Reader, options ...PublishOption) (*Message, error) { - topicURL := c.expandTopicURL(topic) + topicURL, err := c.expandTopicURL(topic) + if err != nil { + return nil, err + } req, err := http.NewRequest("POST", topicURL, body) if err != nil { return nil, err @@ -136,11 +141,14 @@ func (c *Client) PublishReader(topic string, body io.Reader, options ...PublishO // By default, all messages will be returned, but you can change this behavior using a SubscribeOption. // See WithSince, WithSinceAll, WithSinceUnixTime, WithScheduled, and the generic WithQueryParam. func (c *Client) Poll(topic string, options ...SubscribeOption) ([]*Message, error) { + topicURL, err := c.expandTopicURL(topic) + if err != nil { + return nil, err + } ctx := context.Background() messages := make([]*Message, 0) msgChan := make(chan *Message) errChan := make(chan error) - topicURL := c.expandTopicURL(topic) log.Debug("%s Polling from topic", util.ShortTopicURL(topicURL)) options = append(options, WithPoll()) go func() { @@ -169,15 +177,18 @@ func (c *Client) Poll(topic string, options ...SubscribeOption) ([]*Message, err // Example: // // c := client.New(client.NewConfig()) -// subscriptionID := c.Subscribe("mytopic") +// subscriptionID, _ := c.Subscribe("mytopic") // for m := range c.Messages { // fmt.Printf("New message: %s", m.Message) // } -func (c *Client) Subscribe(topic string, options ...SubscribeOption) string { +func (c *Client) Subscribe(topic string, options ...SubscribeOption) (string, error) { + topicURL, err := c.expandTopicURL(topic) + if err != nil { + return "", err + } c.mu.Lock() defer c.mu.Unlock() subscriptionID := util.RandomString(10) - topicURL := c.expandTopicURL(topic) log.Debug("%s Subscribing to topic", util.ShortTopicURL(topicURL)) ctx, cancel := context.WithCancel(context.Background()) c.subscriptions[subscriptionID] = &subscription{ @@ -186,7 +197,7 @@ func (c *Client) Subscribe(topic string, options ...SubscribeOption) string { cancel: cancel, } go handleSubscribeConnLoop(ctx, c.Messages, topicURL, subscriptionID, options...) - return subscriptionID + return subscriptionID, nil } // Unsubscribe unsubscribes from a topic that has been previously subscribed to using the unique @@ -202,31 +213,16 @@ func (c *Client) Unsubscribe(subscriptionID string) { sub.cancel() } -// UnsubscribeAll unsubscribes from a topic that has been previously subscribed with Subscribe. -// If there are multiple subscriptions matching the topic, all of them are unsubscribed from. -// -// A topic can be either a full URL (e.g. https://myhost.lan/mytopic), a short URL which is then prepended https:// -// (e.g. myhost.lan -> https://myhost.lan), or a short name which is expanded using the default host in the -// config (e.g. mytopic -> https://ntfy.sh/mytopic). -func (c *Client) UnsubscribeAll(topic string) { - c.mu.Lock() - defer c.mu.Unlock() - topicURL := c.expandTopicURL(topic) - for _, sub := range c.subscriptions { - if sub.topicURL == topicURL { - delete(c.subscriptions, sub.ID) - sub.cancel() - } - } -} - -func (c *Client) expandTopicURL(topic string) string { +func (c *Client) expandTopicURL(topic string) (string, error) { if strings.HasPrefix(topic, "http://") || strings.HasPrefix(topic, "https://") { - return topic + return topic, nil } else if strings.Contains(topic, "/") { - return fmt.Sprintf("https://%s", topic) + return fmt.Sprintf("https://%s", topic), nil } - return fmt.Sprintf("%s/%s", c.config.DefaultHost, topic) + if !topicRegex.MatchString(topic) { + return "", fmt.Errorf("invalid topic name: %s", topic) + } + return fmt.Sprintf("%s/%s", c.config.DefaultHost, topic), nil } func handleSubscribeConnLoop(ctx context.Context, msgChan chan *Message, topicURL, subcriptionID string, options ...SubscribeOption) { diff --git a/client/client_test.go b/client/client_test.go index a71ea5cb..f0b15a3f 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -21,7 +21,7 @@ func TestClient_Publish_Subscribe(t *testing.T) { defer test.StopServer(t, s, port) c := client.New(newTestConfig(port)) - subscriptionID := c.Subscribe("mytopic") + subscriptionID, _ := c.Subscribe("mytopic") time.Sleep(time.Second) msg, err := c.Publish("mytopic", "some message") diff --git a/cmd/app.go b/cmd/app.go index fd992633..edef5b47 100644 --- a/cmd/app.go +++ b/cmd/app.go @@ -29,7 +29,6 @@ var flagsDefault = []cli.Flag{ var ( logLevelOverrideRegex = regexp.MustCompile(`(?i)^([^=\s]+)(?:\s*=\s*(\S+))?\s*->\s*(TRACE|DEBUG|INFO|WARN|ERROR)$`) - topicRegex = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`) // Same as in server/server.go ) // New creates a new CLI application diff --git a/cmd/publish.go b/cmd/publish.go index bce27e0f..0179f9fa 100644 --- a/cmd/publish.go +++ b/cmd/publish.go @@ -249,10 +249,6 @@ func parseTopicMessageCommand(c *cli.Context) (topic string, message string, com if c.String("message") != "" { message = c.String("message") } - if !topicRegex.MatchString(topic) { - err = fmt.Errorf("topic %s contains invalid characters", topic) - return - } return } diff --git a/cmd/subscribe.go b/cmd/subscribe.go index 81f5988c..c85c4686 100644 --- a/cmd/subscribe.go +++ b/cmd/subscribe.go @@ -108,8 +108,6 @@ func execSubscribe(c *cli.Context) error { // Checks if user != "" && token != "" { return errors.New("cannot set both --user and --token") - } else if !topicRegex.MatchString(topic) { - return fmt.Errorf("topic %s contains invalid characters", topic) } if !fromConfig { @@ -196,7 +194,10 @@ func doSubscribe(c *cli.Context, cl *client.Client, conf *client.Config, topic, topicOptions = append(topicOptions, auth) } - subscriptionID := cl.Subscribe(s.Topic, topicOptions...) + subscriptionID, err := cl.Subscribe(s.Topic, topicOptions...) + if err != nil { + return err + } if s.Command != "" { cmds[subscriptionID] = s.Command } else if conf.DefaultCommand != "" { @@ -206,7 +207,10 @@ func doSubscribe(c *cli.Context, cl *client.Client, conf *client.Config, topic, } } if topic != "" { - subscriptionID := cl.Subscribe(topic, options...) + subscriptionID, err := cl.Subscribe(topic, options...) + if err != nil { + return err + } cmds[subscriptionID] = command } for m := range cl.Messages {