diff --git a/server/server.go b/server/server.go index 1e1e96fa..aebc216a 100644 --- a/server/server.go +++ b/server/server.go @@ -139,7 +139,7 @@ var ( errHTTPBadRequestTopicInvalid = &errHTTP{40009, http.StatusBadRequest, "invalid topic: path invalid", ""} errHTTPBadRequestTopicDisallowed = &errHTTP{40010, http.StatusBadRequest, "invalid topic: topic name is disallowed", ""} errHTTPBadRequestMessageNotUTF8 = &errHTTP{40011, http.StatusBadRequest, "invalid message: message must be UTF-8 encoded", ""} - errHTTPBadRequestMessageTooLarge = &errHTTP{40012, http.StatusBadRequest, "invalid message: too large", ""} + errHTTPBadRequestAttachmentTooLarge = &errHTTP{40012, http.StatusBadRequest, "invalid request: attachment too large", ""} errHTTPBadRequestAttachmentURLInvalid = &errHTTP{40013, http.StatusBadRequest, "invalid request: attachment URL is invalid", ""} errHTTPBadRequestAttachmentURLPeakGeneral = &errHTTP{40014, http.StatusBadRequest, "invalid request: attachment URL peak failed", ""} errHTTPBadRequestAttachmentURLPeakNon2xx = &errHTTP{40015, http.StatusBadRequest, "invalid request: attachment URL peak failed with non-2xx status code", ""} @@ -458,7 +458,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito if err := maybePeakAttachmentURL(m); err != nil { return err } - if err := s.handlePublishBody(v, m, body); err != nil { + if err := s.handlePublishBody(r, v, m, body); err != nil { return err } if m.Message == "" { @@ -592,15 +592,15 @@ func readParam(r *http.Request, names ...string) string { // If file.txt is <= 4096 (message limit) and valid UTF-8, treat it as a message // 4. curl -T file.txt ntfy.sh/mytopic // If file.txt is > message limit, treat it as an attachment -func (s *Server) handlePublishBody(v *visitor, m *message, body *util.PeakedReadCloser) error { +func (s *Server) handlePublishBody(r *http.Request, v *visitor, m *message, body *util.PeakedReadCloser) error { if m.Attachment != nil && m.Attachment.URL != "" { return s.handleBodyAsMessage(m, body) // Case 1 } else if m.Attachment != nil && m.Attachment.Name != "" { - return s.handleBodyAsAttachment(v, m, body) // Case 2 + return s.handleBodyAsAttachment(r, v, m, body) // Case 2 } else if !body.LimitReached && utf8.Valid(body.PeakedBytes) { return s.handleBodyAsMessage(m, body) // Case 3 } - return s.handleBodyAsAttachment(v, m, body) // Case 4 + return s.handleBodyAsAttachment(r, v, m, body) // Case 4 } func (s *Server) handleBodyAsMessage(m *message, body *util.PeakedReadCloser) error { @@ -616,16 +616,27 @@ func (s *Server) handleBodyAsMessage(m *message, body *util.PeakedReadCloser) er return nil } -func (s *Server) handleBodyAsAttachment(v *visitor, m *message, body *util.PeakedReadCloser) error { +func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message, body *util.PeakedReadCloser) error { if s.fileCache == nil { return errHTTPBadRequestAttachmentsDisallowed } else if m.Time > time.Now().Add(s.config.AttachmentExpiryDuration).Unix() { return errHTTPBadRequestAttachmentsExpiryBeforeDelivery } + visitorAttachmentsSize, err := s.cache.AttachmentsSize(v.ip) + if err != nil { + return err + } + remainingVisitorAttachmentSize := s.config.VisitorAttachmentTotalSizeLimit - visitorAttachmentsSize + contentLengthStr := r.Header.Get("Content-Length") + if contentLengthStr != "" { // Early "do-not-trust" check, hard limit see below + contentLength, err := strconv.ParseInt(contentLengthStr, 10, 64) + if err == nil && (contentLength > remainingVisitorAttachmentSize || contentLength > s.config.AttachmentFileSizeLimit) { + return errHTTPBadRequestAttachmentTooLarge + } + } if m.Attachment == nil { m.Attachment = &attachment{} } - var err error var ext string m.Attachment.Owner = v.ip // Important for attachment rate limiting m.Attachment.Expires = time.Now().Add(s.config.AttachmentExpiryDuration).Unix() @@ -637,18 +648,12 @@ func (s *Server) handleBodyAsAttachment(v *visitor, m *message, body *util.Peake if m.Message == "" { m.Message = fmt.Sprintf(defaultAttachmentMessage, m.Attachment.Name) } - visitorAttachmentsSize, err := s.cache.AttachmentsSize(v.ip) - if err != nil { - return err - } - remainingVisitorAttachmentSize := s.config.VisitorAttachmentTotalSizeLimit - visitorAttachmentsSize m.Attachment.Size, err = s.fileCache.Write(m.ID, body, util.NewLimiter(remainingVisitorAttachmentSize)) if err == util.ErrLimitReached { - return errHTTPBadRequestMessageTooLarge + return errHTTPBadRequestAttachmentTooLarge } else if err != nil { return err } - return nil }