From 9514e97219e0f0e8be48963509fe6c28d1099f91 Mon Sep 17 00:00:00 2001 From: Philipp Heckel Date: Fri, 15 Jul 2022 16:52:37 -0400 Subject: [PATCH] Multipart encryption stuff --- server/errors.go | 2 ++ server/server.go | 67 ++++++++++++++++++++++++++++++++++++----- server/server_test.go | 70 +++++++++++++++++++++++++++++++++++++++++++ server/types.go | 4 +-- util/util.go | 2 ++ 5 files changed, 135 insertions(+), 10 deletions(-) diff --git a/server/errors.go b/server/errors.go index 28dbca3a..377e924f 100644 --- a/server/errors.go +++ b/server/errors.go @@ -52,11 +52,13 @@ var ( errHTTPBadRequestActionsInvalid = &errHTTP{40018, http.StatusBadRequest, "invalid request: actions invalid", "https://ntfy.sh/docs/publish/#action-buttons"} errHTTPBadRequestMatrixMessageInvalid = &errHTTP{40019, http.StatusBadRequest, "invalid request: Matrix JSON invalid", "https://ntfy.sh/docs/publish/#matrix-gateway"} errHTTPBadRequestMatrixPushkeyBaseURLMismatch = &errHTTP{40020, http.StatusBadRequest, "invalid request: push key must be prefixed with base URL", "https://ntfy.sh/docs/publish/#matrix-gateway"} + errHTTPBadRequestUnexpectedMultipartField = &errHTTP{40021, http.StatusBadRequest, "invalid request: unexpected multipart field", "https://ntfy.sh/docs/publish/#end-to-end-encryption"} errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", ""} errHTTPUnauthorized = &errHTTP{40101, http.StatusUnauthorized, "unauthorized", "https://ntfy.sh/docs/publish/#authentication"} errHTTPForbidden = &errHTTP{40301, http.StatusForbidden, "forbidden", "https://ntfy.sh/docs/publish/#authentication"} errHTTPEntityTooLargeAttachmentTooLarge = &errHTTP{41301, http.StatusRequestEntityTooLarge, "attachment too large, or bandwidth limit reached", "https://ntfy.sh/docs/publish/#limitations"} errHTTPEntityTooLargeMatrixRequestTooLarge = &errHTTP{41302, http.StatusRequestEntityTooLarge, "Matrix request is larger than the max allowed length", ""} + errHTTPEntityTooLargeEncryptedMessageTooLarge = &errHTTP{41303, http.StatusRequestEntityTooLarge, "encrypted message payload too large", "https://ntfy.sh/docs/publish/#end-to-end-encryption"} errHTTPTooManyRequestsLimitRequests = &errHTTP{42901, http.StatusTooManyRequests, "limit reached: too many requests, please be nice", "https://ntfy.sh/docs/publish/#limitations"} errHTTPTooManyRequestsLimitEmails = &errHTTP{42902, http.StatusTooManyRequests, "limit reached: too many emails, please be nice", "https://ntfy.sh/docs/publish/#limitations"} errHTTPTooManyRequestsLimitSubscriptions = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"} diff --git a/server/server.go b/server/server.go index 74a31c4d..996a6d4a 100644 --- a/server/server.go +++ b/server/server.go @@ -415,7 +415,7 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor) } messageID := matches[1] file := filepath.Join(s.config.AttachmentCacheDir, messageID) - stat, err := os.Stat(file) + stat, err := os.Stat(file) // TODO: Why is this here and not in fileCache?! if err != nil { return errHTTPNotFound } @@ -450,21 +450,25 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes if err != nil { return nil, err } - body, err := util.Peek(r.Body, s.config.MessageLimit) - if err != nil { - return nil, err - } m := newDefaultMessage(t.ID, "") cache, firebase, email, unifiedpush, err := s.parsePublishParams(r, v, m) if err != nil { return nil, err } + var body *util.PeekedReadCloser + if m.Encoding == encodingJWE { + m = newEncryptedMessage(t.ID) + if body, err = s.handlePublishEncrypted(r, m); err != nil { + return nil, err + } + } else { + if body, err = util.Peek(r.Body, s.config.MessageLimit); err != nil { + return nil, err + } + } if m.PollID != "" { m = newPollRequestMessage(t.ID, m.PollID) } - if m.Encoding == encodingJWE { - m = newEncryptedMessage(t.ID, m.Message) - } if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil { return nil, err } @@ -525,6 +529,50 @@ func (s *Server) handlePublishMatrix(w http.ResponseWriter, r *http.Request, v * return writeMatrixSuccess(w) } +func (s *Server) handlePublishEncrypted(r *http.Request, m *message) (body *util.PeekedReadCloser, err error) { + multipart := strings.HasPrefix(r.Header.Get("Content-Type"), "multipart/") + if multipart { + mp, err := r.MultipartReader() + if err != nil { + return nil, err + } + p, err := mp.NextPart() + if err != nil { + return nil, err + } else if p.FormName() != "message" { + return nil, errHTTPBadRequestUnexpectedMultipartField + } + messageBody, err := util.PeekLimit(p, s.config.MessageLimit) + if err == util.ErrLimitReached { + return nil, errHTTPEntityTooLargeEncryptedMessageTooLarge + } else if err != nil { + return nil, err + } + m.Message = string(messageBody.PeekedBytes) + p, err = mp.NextPart() + if err != nil { + return nil, err + } else if p.FormName() != "attachment" { + return nil, errHTTPBadRequestUnexpectedMultipartField + } + m.Attachment = &attachment{ + Name: "attachment.jwe", // Force handlePublishBody into "attachment" mode + } + body, err = util.Peek(p, s.config.MessageLimit) + if err != nil { + return nil, err + } + } else { + if body, err = util.PeekLimit(r.Body, s.config.MessageLimit); err == util.ErrLimitReached { + return nil, errHTTPEntityTooLargeEncryptedMessageTooLarge + } else if err != nil { + return nil, err + } + m.Message = string(body.PeekedBytes) + } + return body, nil +} + func (s *Server) sendToFirebase(v *visitor, m *message) { log.Debug("%s Publishing to Firebase", logMessagePrefix(v, m)) if err := s.firebaseClient.Send(v, m); err != nil { @@ -622,6 +670,9 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca m.Tags = append(m.Tags, strings.TrimSpace(s)) } } + if encoding := readParam(r, "x-encoding", "encoding"); encoding == encodingJWE { + m.Encoding = encoding + } delayStr := readParam(r, "x-delay", "delay", "x-at", "at", "x-in", "in") if delayStr != "" { if !cache { diff --git a/server/server_test.go b/server/server_test.go index d68cfa11..0384e78c 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -2,6 +2,7 @@ package server import ( "bufio" + "bytes" "context" "encoding/base64" "encoding/json" @@ -10,6 +11,7 @@ import ( "io" "log" "math/rand" + "mime/multipart" "net/http" "net/http/httptest" "path/filepath" @@ -1459,6 +1461,51 @@ func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) { log.Printf("Done: Waiting for all locks") } +func TestServer_PublishEncrypted_Simple(t *testing.T) { + s := newTestServer(t, newTestConfig(t)) + ciphertext := "eyJhbGciOiJkaXIiLCJlbmMiOiJBMjU2R0NNIn0..gSRYZeX6eBhlj13w.LOchcxFXwALXE2GqdoSwFJEXdMyEbLfLKV9geXr17WrAN-nH7ya1VQ_Y6ebT1w.2eyLaTUfc_rpKaZr4-5I1Q" + response := request(t, s, "PUT", "/mytopic", ciphertext, map[string]string{ + "Encoding": "jwe", + "Title": "this will be stripped", + }) + m := toMessage(t, response.Body.String()) + require.Equal(t, "jwe", m.Encoding) + require.Equal(t, "eyJhbGciOiJkaXIiLCJlbmMiOiJBMjU2R0NNIn0..gSRYZeX6eBhlj13w.LOchcxFXwALXE2GqdoSwFJEXdMyEbLfLKV9geXr17WrAN-nH7ya1VQ_Y6ebT1w.2eyLaTUfc_rpKaZr4-5I1Q", m.Message) + require.Equal(t, "", m.Title) +} + +func TestServer_PublishEncrypted_Simple_TooLarge(t *testing.T) { + s := newTestServer(t, newTestConfig(t)) + ciphertext := util.RandomString(5001) // > 4096 + response := request(t, s, "PUT", "/mytopic", ciphertext, map[string]string{ + "Encoding": "jwe", + }) + err := toHTTPError(t, response.Body.String()) + require.Equal(t, 413, err.HTTPCode) + require.Equal(t, 41303, err.Code) +} + +func TestServer_PublishEncrypted_WithAttachment(t *testing.T) { + s := newTestServer(t, newTestConfig(t)) + parts := map[string]string{ + "message": "eyJhbGciOiJkaXIiLCJlbmMiOiJBMjU2R0NNIn0..gSRYZeX6eBhlj13w.LOchcxFXwALXE2GqdoSwFJEXdMyEbLfLKV9geXr17WrAN-nH7ya1VQ_Y6ebT1w.2eyLaTUfc_rpKaZr4-5I1Q", + "attachment": "eyJhbGciOiJkaXIiLCJlbmMiOiJBMjU2R0NNIn0..vbe1Qv_-mKYbUgce.EfmOUIUi7lxXZG_o4bqXZ9pmpr1Rzs4Y5QLE2XD2_aw_SQ.y2hadrN5b2LEw7_PJHhbcA", + } + response := requestMultipart(t, s, "PUT", "/mytopic", parts, map[string]string{ + "Encoding": "jwe", + }) + m := toMessage(t, response.Body.String()) + require.Equal(t, "jwe", m.Encoding) + require.Equal(t, "eyJhbGciOiJkaXIiLCJlbmMiOiJBMjU2R0NNIn0..gSRYZeX6eBhlj13w.LOchcxFXwALXE2GqdoSwFJEXdMyEbLfLKV9geXr17WrAN-nH7ya1VQ_Y6ebT1w.2eyLaTUfc_rpKaZr4-5I1Q", m.Message) + require.Equal(t, "attachment.jwe", m.Attachment.Name) + require.Equal(t, "application/jose", m.Attachment.Type) + require.Equal(t, int64(127), m.Attachment.Size) + + file := filepath.Join(s.config.AttachmentCacheDir, m.ID) + require.FileExists(t, file) + require.Equal(t, "eyJhbGciOiJkaXIiLCJlbmMiOiJBMjU2R0NNIn0..vbe1Qv_-mKYbUgce.EfmOUIUi7lxXZG_o4bqXZ9pmpr1Rzs4Y5QLE2XD2_aw_SQ.y2hadrN5b2LEw7_PJHhbcA", readFile(t, file)) +} + func newTestConfig(t *testing.T) *Config { conf := NewConfig() conf.BaseURL = "http://127.0.0.1:12345" @@ -1489,6 +1536,29 @@ func request(t *testing.T, s *Server, method, url, body string, headers map[stri return rr } +func requestMultipart(t *testing.T, s *Server, method, url string, parts map[string]string, headers map[string]string) *httptest.ResponseRecorder { + var b bytes.Buffer + w := multipart.NewWriter(&b) + for k, v := range parts { + mw, _ := w.CreateFormField(k) + _, err := io.Copy(mw, strings.NewReader(v)) + require.Nil(t, err) + } + require.Nil(t, w.Close()) + rr := httptest.NewRecorder() + req, err := http.NewRequest(method, url, &b) + if err != nil { + t.Fatal(err) + } + req.RemoteAddr = "9.9.9.9" // Used for tests + req.Header.Set("Content-Type", w.FormDataContentType()) + for k, v := range headers { + req.Header.Set(k, v) + } + s.handle(rr, req) + return rr +} + func subscribe(t *testing.T, s *Server, url string, rr *httptest.ResponseRecorder) context.CancelFunc { ctx, cancel := context.WithCancel(context.Background()) req, err := http.NewRequestWithContext(ctx, "GET", url, nil) diff --git a/server/types.go b/server/types.go index d59e748b..dc488d23 100644 --- a/server/types.go +++ b/server/types.go @@ -115,8 +115,8 @@ func newPollRequestMessage(topic, pollID string) *message { return m } -func newEncryptedMessage(topic, msg string) *message { - m := newMessage(messageEvent, topic, msg) +func newEncryptedMessage(topic string) *message { + m := newMessage(messageEvent, topic, "") m.Encoding = encodingJWE return m } diff --git a/util/util.go b/util/util.go index ce18f093..efa8793e 100644 --- a/util/util.go +++ b/util/util.go @@ -177,6 +177,8 @@ func ShortTopicURL(s string) string { func DetectContentType(b []byte, filename string) (mimeType string, ext string) { if strings.HasSuffix(strings.ToLower(filename), ".apk") { return "application/vnd.android.package-archive", ".apk" + } else if strings.HasSuffix(strings.ToLower(filename), ".jwe") { + return "application/jose", ".jwe" } m := mimetype.Detect(b) mimeType, ext = m.String(), m.Extension()