diff --git a/server/server.go b/server/server.go index 996a6d4a..d424c1b9 100644 --- a/server/server.go +++ b/server/server.go @@ -96,6 +96,8 @@ const ( defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment encodingBase64 = "base64" encodingJWE = "jwe" + multipartFieldMessage = "message" + multipartFieldAttachment = "attachment" ) // WebSocket constants @@ -539,8 +541,8 @@ func (s *Server) handlePublishEncrypted(r *http.Request, m *message) (body *util p, err := mp.NextPart() if err != nil { return nil, err - } else if p.FormName() != "message" { - return nil, errHTTPBadRequestUnexpectedMultipartField + } else if p.FormName() != multipartFieldMessage { + return nil, wrapErrHTTP(errHTTPBadRequestUnexpectedMultipartField, "expected '%s', got '%s'", multipartFieldMessage, p.FormName()) } messageBody, err := util.PeekLimit(p, s.config.MessageLimit) if err == util.ErrLimitReached { @@ -552,11 +554,11 @@ func (s *Server) handlePublishEncrypted(r *http.Request, m *message) (body *util p, err = mp.NextPart() if err != nil { return nil, err - } else if p.FormName() != "attachment" { - return nil, errHTTPBadRequestUnexpectedMultipartField + } else if p.FormName() != multipartFieldAttachment { + return nil, wrapErrHTTP(errHTTPBadRequestUnexpectedMultipartField, "expected '%s', got '%s'", multipartFieldAttachment, p.FormName()) } m.Attachment = &attachment{ - Name: "attachment.jwe", // Force handlePublishBody into "attachment" mode + Name: "attachment.jwe", // Force handlePublishBody into "attachment" mode; .jwe forces application/jose type } body, err = util.Peek(p, s.config.MessageLimit) if err != nil { diff --git a/server/server_test.go b/server/server_test.go index 0384e78c..9edbd9e7 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1487,9 +1487,9 @@ func TestServer_PublishEncrypted_Simple_TooLarge(t *testing.T) { func TestServer_PublishEncrypted_WithAttachment(t *testing.T) { s := newTestServer(t, newTestConfig(t)) - parts := map[string]string{ - "message": "eyJhbGciOiJkaXIiLCJlbmMiOiJBMjU2R0NNIn0..gSRYZeX6eBhlj13w.LOchcxFXwALXE2GqdoSwFJEXdMyEbLfLKV9geXr17WrAN-nH7ya1VQ_Y6ebT1w.2eyLaTUfc_rpKaZr4-5I1Q", - "attachment": "eyJhbGciOiJkaXIiLCJlbmMiOiJBMjU2R0NNIn0..vbe1Qv_-mKYbUgce.EfmOUIUi7lxXZG_o4bqXZ9pmpr1Rzs4Y5QLE2XD2_aw_SQ.y2hadrN5b2LEw7_PJHhbcA", + parts := []mpart{ + {"message", "eyJhbGciOiJkaXIiLCJlbmMiOiJBMjU2R0NNIn0..gSRYZeX6eBhlj13w.LOchcxFXwALXE2GqdoSwFJEXdMyEbLfLKV9geXr17WrAN-nH7ya1VQ_Y6ebT1w.2eyLaTUfc_rpKaZr4-5I1Q"}, + {"attachment", "eyJhbGciOiJkaXIiLCJlbmMiOiJBMjU2R0NNIn0..vbe1Qv_-mKYbUgce.EfmOUIUi7lxXZG_o4bqXZ9pmpr1Rzs4Y5QLE2XD2_aw_SQ.y2hadrN5b2LEw7_PJHhbcA"}, } response := requestMultipart(t, s, "PUT", "/mytopic", parts, map[string]string{ "Encoding": "jwe", @@ -1506,6 +1506,37 @@ func TestServer_PublishEncrypted_WithAttachment(t *testing.T) { require.Equal(t, "eyJhbGciOiJkaXIiLCJlbmMiOiJBMjU2R0NNIn0..vbe1Qv_-mKYbUgce.EfmOUIUi7lxXZG_o4bqXZ9pmpr1Rzs4Y5QLE2XD2_aw_SQ.y2hadrN5b2LEw7_PJHhbcA", readFile(t, file)) } +func TestServer_PublishEncrypted_WithAttachment_TooLarge_Attachment(t *testing.T) { + c := newTestConfig(t) + c.AttachmentFileSizeLimit = 5000 + s := newTestServer(t, c) + parts := []mpart{ + {"message", "eyJhbGciOiJkaXIiLCJlbmMiOiJBMjU2R0NNIn0..gSRYZeX6eBhlj13w.LOchcxFXwALXE2GqdoSwFJEXdMyEbLfLKV9geXr17WrAN-nH7ya1VQ_Y6ebT1w.2eyLaTUfc_rpKaZr4-5I1Q"}, + {"attachment", strings.Repeat("a", 5001)}, // > 5000 + } + response := requestMultipart(t, s, "PUT", "/mytopic", parts, map[string]string{ + "Encoding": "jwe", + }) + err := toHTTPError(t, response.Body.String()) + require.Equal(t, 413, err.HTTPCode) + require.Equal(t, 41301, err.Code) +} + +func TestServer_PublishEncrypted_WithAttachment_TooLarge_Message(t *testing.T) { + s := newTestServer(t, newTestConfig(t)) + parts := []mpart{ + {"message", strings.Repeat("a", 5000)}, + {"attachment", "eyJhbGciOiJkaXIiLCJlbmMiOiJBMjU2R0NNIn0..gSRYZeX6eBhlj13w.LOchcxFXwALXE2GqdoSwFJEXdMyEbLfLKV9geXr17WrAN-nH7ya1VQ_Y6ebT1w.2eyLaTUfc_rpKaZr4-5I1Q"}, + } + response := requestMultipart(t, s, "PUT", "/mytopic", parts, map[string]string{ + "Encoding": "jwe", + }) + err := toHTTPError(t, response.Body.String()) + log.Printf(err.Error()) + require.Equal(t, 413, err.HTTPCode) + require.Equal(t, 41303, err.Code) +} + func newTestConfig(t *testing.T) *Config { conf := NewConfig() conf.BaseURL = "http://127.0.0.1:12345" @@ -1536,12 +1567,16 @@ func request(t *testing.T, s *Server, method, url, body string, headers map[stri return rr } -func requestMultipart(t *testing.T, s *Server, method, url string, parts map[string]string, headers map[string]string) *httptest.ResponseRecorder { +type mpart struct { + key, value string +} + +func requestMultipart(t *testing.T, s *Server, method, url string, parts []mpart, headers map[string]string) *httptest.ResponseRecorder { var b bytes.Buffer w := multipart.NewWriter(&b) - for k, v := range parts { - mw, _ := w.CreateFormField(k) - _, err := io.Copy(mw, strings.NewReader(v)) + for _, part := range parts { + mw, _ := w.CreateFormField(part.key) + _, err := io.Copy(mw, strings.NewReader(part.value)) require.Nil(t, err) } require.Nil(t, w.Close())