1
0
Fork 0
mirror of https://github.com/binwiederhier/ntfy.git synced 2024-11-27 13:44:59 +01:00
This commit is contained in:
Philipp Heckel 2022-07-18 14:37:51 -04:00
parent 09cb1482b4
commit 466c9874a8
6 changed files with 114 additions and 76 deletions

View file

@ -3,14 +3,19 @@ package client
import ( import (
"bufio" "bufio"
"bytes"
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/crypto"
"heckel.io/ntfy/log" "heckel.io/ntfy/log"
"heckel.io/ntfy/util" "heckel.io/ntfy/util"
"io" "io"
"mime/multipart"
"net/http" "net/http"
"net/http/httptest"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -25,7 +30,8 @@ const (
) )
const ( const (
maxResponseBytes = 4096 maxResponseBytes = 4096
encryptedMessageBytesLimit = 100 * 1024 * 1024 // 100 MB
) )
// Client is the ntfy client that can be used to publish and subscribe to ntfy topics // Client is the ntfy client that can be used to publish and subscribe to ntfy topics
@ -95,7 +101,7 @@ func (c *Client) Publish(topic, message string, options ...PublishOption) (*Mess
// To pass title, priority and tags, check out WithTitle, WithPriority, WithTagsList, WithDelay, WithNoCache, // To pass title, priority and tags, check out WithTitle, WithPriority, WithTagsList, WithDelay, WithNoCache,
// WithNoFirebase, and the generic WithHeader. // WithNoFirebase, and the generic WithHeader.
func (c *Client) PublishReader(topic string, body io.Reader, options ...PublishOption) (*Message, error) { func (c *Client) PublishReader(topic string, body io.Reader, options ...PublishOption) (*Message, error) {
topicURL := c.expandTopicURL(topic) topicURL := util.ExpandTopicURL(topic, c.config.DefaultHost)
req, _ := http.NewRequest("POST", topicURL, body) req, _ := http.NewRequest("POST", topicURL, body)
for _, option := range options { for _, option := range options {
if err := option(req); err != nil { if err := option(req); err != nil {
@ -122,6 +128,59 @@ func (c *Client) PublishReader(topic string, body io.Reader, options ...PublishO
return m, nil return m, nil
} }
func (c *Client) PublishEncryptedReader(topic string, body io.Reader, password string, options ...PublishOption) (*Message, error) {
topicURL := util.ExpandTopicURL(topic, c.config.DefaultHost)
key := crypto.DeriveKey(password, topicURL)
peaked, err := util.PeekLimit(io.NopCloser(body), encryptedMessageBytesLimit)
if err != nil {
return nil, err
}
ciphertext, err := crypto.Encrypt(peaked.PeekedBytes, key)
if err != nil {
return nil, err
}
var b bytes.Buffer
body = strings.NewReader(ciphertext)
w := multipart.NewWriter(&b)
for _, part := range parts {
mw, _ := w.CreateFormField(part.key)
_, err := io.Copy(mw, strings.NewReader(part.value))
require.Nil(t, err)
}
require.Nil(t, w.Close())
rr := httptest.NewRecorder()
req, err := http.NewRequest(method, url, &b)
if err != nil {
t.Fatal(err)
}
req, _ := http.NewRequest("POST", topicURL, body)
req.Header.Set("X-Encoding", "jwe")
for _, option := range options {
if err := option(req); err != nil {
return nil, err
}
}
log.Debug("%s Publishing message with headers %s", util.ShortTopicURL(topicURL), req.Header)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
b, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, errors.New(strings.TrimSpace(string(b)))
}
m, err := toMessage(string(b), topicURL, "")
if err != nil {
return nil, err
}
return m, nil
}
// Poll queries a topic for all (or a limited set) of messages. Unlike Subscribe, this method only polls for // Poll queries a topic for all (or a limited set) of messages. Unlike Subscribe, this method only polls for
// messages and does not subscribe to messages that arrive after this call. // messages and does not subscribe to messages that arrive after this call.
// //
@ -136,7 +195,7 @@ func (c *Client) Poll(topic string, options ...SubscribeOption) ([]*Message, err
messages := make([]*Message, 0) messages := make([]*Message, 0)
msgChan := make(chan *Message) msgChan := make(chan *Message)
errChan := make(chan error) errChan := make(chan error)
topicURL := c.expandTopicURL(topic) topicURL := util.ExpandTopicURL(topic, c.config.DefaultHost)
log.Debug("%s Polling from topic", util.ShortTopicURL(topicURL)) log.Debug("%s Polling from topic", util.ShortTopicURL(topicURL))
options = append(options, WithPoll()) options = append(options, WithPoll())
go func() { go func() {
@ -172,7 +231,7 @@ func (c *Client) Subscribe(topic string, options ...SubscribeOption) string {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
subscriptionID := util.RandomString(10) subscriptionID := util.RandomString(10)
topicURL := c.expandTopicURL(topic) topicURL := util.ExpandTopicURL(topic, c.config.DefaultHost)
log.Debug("%s Subscribing to topic", util.ShortTopicURL(topicURL)) log.Debug("%s Subscribing to topic", util.ShortTopicURL(topicURL))
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
c.subscriptions[subscriptionID] = &subscription{ c.subscriptions[subscriptionID] = &subscription{
@ -206,7 +265,7 @@ func (c *Client) Unsubscribe(subscriptionID string) {
func (c *Client) UnsubscribeAll(topic string) { func (c *Client) UnsubscribeAll(topic string) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
topicURL := c.expandTopicURL(topic) topicURL := util.ExpandTopicURL(topic, c.config.DefaultHost)
for _, sub := range c.subscriptions { for _, sub := range c.subscriptions {
if sub.topicURL == topicURL { if sub.topicURL == topicURL {
delete(c.subscriptions, sub.ID) delete(c.subscriptions, sub.ID)
@ -215,15 +274,6 @@ func (c *Client) UnsubscribeAll(topic string) {
} }
} }
func (c *Client) expandTopicURL(topic string) string {
if strings.HasPrefix(topic, "http://") || strings.HasPrefix(topic, "https://") {
return topic
} else if strings.Contains(topic, "/") {
return fmt.Sprintf("https://%s", topic)
}
return fmt.Sprintf("%s/%s", c.config.DefaultHost, topic)
}
func handleSubscribeConnLoop(ctx context.Context, msgChan chan *Message, topicURL, subcriptionID string, options ...SubscribeOption) { func handleSubscribeConnLoop(ctx context.Context, msgChan chan *Message, topicURL, subcriptionID string, options ...SubscribeOption) {
for { for {
// TODO The retry logic is crude and may lose messages. It should record the last message like the // TODO The retry logic is crude and may lose messages. It should record the last message like the

View file

@ -92,8 +92,9 @@ func WithNoFirebase() PublishOption {
return WithHeader("X-Firebase", "no") return WithHeader("X-Firebase", "no")
} }
// WithEncrypted sets the encoding header to "jwe"
func WithEncrypted() PublishOption { func WithEncrypted() PublishOption {
return WithHeader("X-Encryption", "jwe") return WithHeader("X-Encoding", "jwe")
} }
// WithSince limits the number of messages returned from the server. The parameter since can be a Unix // WithSince limits the number of messages returned from the server. The parameter since can be a Unix

View file

@ -5,8 +5,8 @@ import (
"fmt" "fmt"
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
"heckel.io/ntfy/client" "heckel.io/ntfy/client"
"heckel.io/ntfy/crypto"
"heckel.io/ntfy/log" "heckel.io/ntfy/log"
"heckel.io/ntfy/server"
"heckel.io/ntfy/util" "heckel.io/ntfy/util"
"io" "io"
"os" "os"
@ -16,10 +16,6 @@ import (
"time" "time"
) )
const (
encryptedMessageBytesLimit = 100 * 1024 * 1024 // 100 MB
)
func init() { func init() {
commands = append(commands, cmdPublish) commands = append(commands, cmdPublish)
} }
@ -110,34 +106,27 @@ func execPublish(c *cli.Context) error {
if err != nil { if err != nil {
return err return err
} }
pm := &server.PublishMessage{
Topic: topic,
Title: title,
Message: message,
Tags: util.SplitNoEmpty(tags, ","),
Click: click,
Actions: nil,
Attach: attach,
Filename: filename,
Email: email,
Delay: delay,
}
var options []client.PublishOption var options []client.PublishOption
if title != "" { p, err := util.ParsePriority(priority)
options = append(options, client.WithTitle(title)) if err != nil {
} return err
if priority != "" {
options = append(options, client.WithPriority(priority))
}
if tags != "" {
options = append(options, client.WithTagsList(tags))
}
if delay != "" {
options = append(options, client.WithDelay(delay))
}
if click != "" {
options = append(options, client.WithClick(click))
} }
pm.Priority = p
if actions != "" { if actions != "" {
options = append(options, client.WithActions(strings.ReplaceAll(actions, "\n", " "))) options = append(options, client.WithActions(strings.ReplaceAll(actions, "\n", " ")))
} }
if attach != "" {
options = append(options, client.WithAttach(attach))
}
if filename != "" {
options = append(options, client.WithFilename(filename))
}
if email != "" {
options = append(options, client.WithEmail(email))
}
if noCache { if noCache {
options = append(options, client.WithNoCache()) options = append(options, client.WithNoCache())
} }
@ -165,15 +154,15 @@ func execPublish(c *cli.Context) error {
newMessage, err := waitForProcess(pid) newMessage, err := waitForProcess(pid)
if err != nil { if err != nil {
return err return err
} else if message == "" { } else if pm.Message == "" {
message = newMessage pm.Message = newMessage
} }
} else if len(command) > 0 { } else if len(command) > 0 {
newMessage, err := runAndWaitForCommand(command) newMessage, err := runAndWaitForCommand(command)
if err != nil { if err != nil {
return err return err
} else if message == "" { } else if pm.Message == "" {
message = newMessage pm.Message = newMessage
} }
} }
var body io.Reader var body io.Reader
@ -198,24 +187,16 @@ func execPublish(c *cli.Context) error {
} }
} }
} }
if password != "" { var m *client.Message
topicURL := expandTopicURL(topic, conf.DefaultHost)
key := crypto.DeriveKey(password, topicURL)
peaked, err := util.PeekLimit(io.NopCloser(body), encryptedMessageBytesLimit)
if err != nil {
return err
}
ciphertext, err := crypto.Encrypt(peaked.PeekedBytes, key)
if err != nil {
return err
}
body = strings.NewReader(ciphertext)
options = append(options, client.WithEncrypted())
}
cl := client.New(conf) cl := client.New(conf)
m, err := cl.PublishReader(topic, body, options...) if password != "" {
if err != nil { if m, err = cl.PublishEncryptedReader(topic, m, password, options...); err != nil {
return err return err
}
} else {
if m, err = cl.PublishReader(topic, m, options...); err != nil {
return err
}
} }
if !quiet { if !quiet {
fmt.Fprintln(c.App.Writer, strings.TrimSpace(m.Raw)) fmt.Fprintln(c.App.Writer, strings.TrimSpace(m.Raw))
@ -223,15 +204,6 @@ func execPublish(c *cli.Context) error {
return nil return nil
} }
func expandTopicURL(topic, defaultHost string) string {
if strings.HasPrefix(topic, "http://") || strings.HasPrefix(topic, "https://") {
return topic
} else if strings.Contains(topic, "/") {
return fmt.Sprintf("https://%s", topic)
}
return fmt.Sprintf("%s/%s", defaultHost, topic)
}
// parseTopicMessageCommand reads the topic and the remaining arguments from the context. // parseTopicMessageCommand reads the topic and the remaining arguments from the context.
// //
// There are a few cases to consider: // There are a few cases to consider:

View file

@ -447,6 +447,11 @@ func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error {
return writeMatrixDiscoveryResponse(w) return writeMatrixDiscoveryResponse(w)
} }
type inputMessage struct {
message
cache bool
}
func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*message, error) { func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*message, error) {
t, err := s.topicFromPath(r.URL.Path) t, err := s.topicFromPath(r.URL.Path)
if err != nil { if err != nil {
@ -1367,7 +1372,7 @@ func (s *Server) transformBodyJSON(next handleFunc) handleFunc {
return err return err
} }
defer r.Body.Close() defer r.Body.Close()
var m publishMessage var m PublishMessage
if err := json.NewDecoder(body).Decode(&m); err != nil { if err := json.NewDecoder(body).Decode(&m); err != nil {
return errHTTPBadRequestJSONInvalid return errHTTPBadRequestJSONInvalid
} }

View file

@ -64,8 +64,8 @@ func newAction() *action {
} }
} }
// publishMessage is used as input when publishing as JSON // PublishMessage is used as input when publishing as JSON
type publishMessage struct { type PublishMessage struct {
Topic string `json:"topic"` Topic string `json:"topic"`
Title string `json:"title"` Title string `json:"title"`
Message string `json:"message"` Message string `json:"message"`

View file

@ -172,6 +172,16 @@ func ShortTopicURL(s string) string {
return strings.TrimPrefix(strings.TrimPrefix(s, "https://"), "http://") return strings.TrimPrefix(strings.TrimPrefix(s, "https://"), "http://")
} }
// ExpandTopicURL expands a topic to a fully qualified URL, e.g. "mytopic" -> "https://ntfy.sh/mytopic"
func ExpandTopicURL(topic, defaultHost string) string {
if strings.HasPrefix(topic, "http://") || strings.HasPrefix(topic, "https://") {
return topic
} else if strings.Contains(topic, "/") {
return fmt.Sprintf("https://%s", topic)
}
return fmt.Sprintf("%s/%s", defaultHost, topic)
}
// DetectContentType probes the byte array b and returns mime type and file extension. // DetectContentType probes the byte array b and returns mime type and file extension.
// The filename is only used to override certain special cases. // The filename is only used to override certain special cases.
func DetectContentType(b []byte, filename string) (mimeType string, ext string) { func DetectContentType(b []byte, filename string) (mimeType string, ext string) {