From eb5b86ffe20a552c6f9e6a3a523539561741da74 Mon Sep 17 00:00:00 2001 From: Philipp Heckel Date: Sun, 2 Jan 2022 23:56:12 +0100 Subject: [PATCH] WIP: Attachments --- .gitignore | 1 + cmd/serve.go | 5 +- docs/publish.md | 1 + server/config.go | 10 ++- server/message.go | 23 ++++-- server/server.go | 128 ++++++++++++++++++++++++++----- server/server_test.go | 11 +-- util/content_type_writer.go | 41 ++++++++++ util/content_type_writer_test.go | 50 ++++++++++++ util/limit.go | 41 ++++++++++ util/limit_test.go | 63 ++++++++++++++- util/peak.go | 61 +++++++++++++++ util/peak_test.go | 55 +++++++++++++ 13 files changed, 444 insertions(+), 46 deletions(-) create mode 100644 util/content_type_writer.go create mode 100644 util/content_type_writer_test.go create mode 100644 util/peak.go create mode 100644 util/peak_test.go diff --git a/.gitignore b/.gitignore index 6dffcf55..6d12c730 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,5 @@ build/ .idea/ server/docs/ tools/fbsend/fbsend +playground/ *.iml diff --git a/cmd/serve.go b/cmd/serve.go index 5545206f..f4161e1a 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -20,6 +20,7 @@ var flagsServe = []cli.Flag{ altsrc.NewStringFlag(&cli.StringFlag{Name: "firebase-key-file", Aliases: []string{"F"}, EnvVars: []string{"NTFY_FIREBASE_KEY_FILE"}, Usage: "Firebase credentials file; if set additionally publish to FCM topic"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-file", Aliases: []string{"C"}, EnvVars: []string{"NTFY_CACHE_FILE"}, Usage: "cache file used for message caching"}), altsrc.NewDurationFlag(&cli.DurationFlag{Name: "cache-duration", Aliases: []string{"b"}, EnvVars: []string{"NTFY_CACHE_DURATION"}, Value: server.DefaultCacheDuration, Usage: "buffer messages for this time to allow `since` requests"}), + altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-cache-dir", EnvVars: []string{"NTFY_ATTACHMENT_CACHE_DIR"}, Usage: "cache directory for attached files"}), altsrc.NewDurationFlag(&cli.DurationFlag{Name: "keepalive-interval", Aliases: []string{"k"}, EnvVars: []string{"NTFY_KEEPALIVE_INTERVAL"}, Value: server.DefaultKeepaliveInterval, Usage: "interval of keepalive messages"}), altsrc.NewDurationFlag(&cli.DurationFlag{Name: "manager-interval", Aliases: []string{"m"}, EnvVars: []string{"NTFY_MANAGER_INTERVAL"}, Value: server.DefaultManagerInterval, Usage: "interval of for message pruning and stats printing"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-sender-addr", EnvVars: []string{"NTFY_SMTP_SENDER_ADDR"}, Usage: "SMTP server address (host:port) for outgoing emails"}), @@ -69,6 +70,7 @@ func execServe(c *cli.Context) error { firebaseKeyFile := c.String("firebase-key-file") cacheFile := c.String("cache-file") cacheDuration := c.Duration("cache-duration") + attachmentCacheDir := c.String("attachment-cache-dir") keepaliveInterval := c.Duration("keepalive-interval") managerInterval := c.Duration("manager-interval") smtpSenderAddr := c.String("smtp-sender-addr") @@ -117,6 +119,7 @@ func execServe(c *cli.Context) error { conf.FirebaseKeyFile = firebaseKeyFile conf.CacheFile = cacheFile conf.CacheDuration = cacheDuration + conf.AttachmentCacheDir = attachmentCacheDir conf.KeepaliveInterval = keepaliveInterval conf.ManagerInterval = managerInterval conf.SMTPSenderAddr = smtpSenderAddr @@ -126,7 +129,7 @@ func execServe(c *cli.Context) error { conf.SMTPServerListen = smtpServerListen conf.SMTPServerDomain = smtpServerDomain conf.SMTPServerAddrPrefix = smtpServerAddrPrefix - conf.GlobalTopicLimit = globalTopicLimit + conf.TotalTopicLimit = globalTopicLimit conf.VisitorSubscriptionLimit = visitorSubscriptionLimit conf.VisitorRequestLimitBurst = visitorRequestLimitBurst conf.VisitorRequestLimitReplenish = visitorRequestLimitReplenish diff --git a/docs/publish.md b/docs/publish.md index ec017e05..46a5a334 100644 --- a/docs/publish.md +++ b/docs/publish.md @@ -886,3 +886,4 @@ and can be passed as **HTTP headers** or **query parameters in the URL**. They a | `X-Email` | `X-E-Mail`, `Email`, `E-Mail`, `mail`, `e` | E-mail address for [e-mail notifications](#e-mail-notifications) | | `X-Cache` | `Cache` | Allows disabling [message caching](#message-caching) | | `X-Firebase` | `Firebase` | Allows disabling [sending to Firebase](#disable-firebase) | +| `X-UnifiedPush` | `UnifiedPush`, `up` | XXXXXXXXXXXXXXXX | diff --git a/server/config.go b/server/config.go index 30f937e9..68a911fb 100644 --- a/server/config.go +++ b/server/config.go @@ -14,6 +14,7 @@ const ( DefaultMinDelay = 10 * time.Second DefaultMaxDelay = 3 * 24 * time.Hour DefaultMessageLimit = 4096 + DefaultAttachmentSizeLimit = 5 * 1024 * 1024 DefaultFirebaseKeepaliveInterval = 3 * time.Hour // Not too frequently to save battery ) @@ -41,6 +42,8 @@ type Config struct { FirebaseKeyFile string CacheFile string CacheDuration time.Duration + AttachmentCacheDir string + AttachmentSizeLimit int64 KeepaliveInterval time.Duration ManagerInterval time.Duration AtSenderInterval time.Duration @@ -55,7 +58,8 @@ type Config struct { MessageLimit int MinDelay time.Duration MaxDelay time.Duration - GlobalTopicLimit int + TotalTopicLimit int + TotalAttachmentSizeLimit int64 VisitorRequestLimitBurst int VisitorRequestLimitReplenish time.Duration VisitorEmailLimitBurst int @@ -75,6 +79,8 @@ func NewConfig() *Config { FirebaseKeyFile: "", CacheFile: "", CacheDuration: DefaultCacheDuration, + AttachmentCacheDir: "", + AttachmentSizeLimit: DefaultAttachmentSizeLimit, KeepaliveInterval: DefaultKeepaliveInterval, ManagerInterval: DefaultManagerInterval, MessageLimit: DefaultMessageLimit, @@ -82,7 +88,7 @@ func NewConfig() *Config { MaxDelay: DefaultMaxDelay, AtSenderInterval: DefaultAtSenderInterval, FirebaseKeepaliveInterval: DefaultFirebaseKeepaliveInterval, - GlobalTopicLimit: DefaultGlobalTopicLimit, + TotalTopicLimit: DefaultGlobalTopicLimit, VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst, VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish, VisitorEmailLimitBurst: DefaultVisitorEmailLimitBurst, diff --git a/server/message.go b/server/message.go index ad870e09..2c3fb198 100644 --- a/server/message.go +++ b/server/message.go @@ -18,14 +18,21 @@ const ( // message represents a message published to a topic type message struct { - ID string `json:"id"` // Random message ID - Time int64 `json:"time"` // Unix time in seconds - Event string `json:"event"` // One of the above - Topic string `json:"topic"` - Priority int `json:"priority,omitempty"` - Tags []string `json:"tags,omitempty"` - Title string `json:"title,omitempty"` - Message string `json:"message,omitempty"` + ID string `json:"id"` // Random message ID + Time int64 `json:"time"` // Unix time in seconds + Event string `json:"event"` // One of the above + Topic string `json:"topic"` + Priority int `json:"priority,omitempty"` + Tags []string `json:"tags,omitempty"` + Title string `json:"title,omitempty"` + Message string `json:"message,omitempty"` + Attachment *attachment `json:"attachment,omitempty"` +} + +type attachment struct { + Name string `json:"name"` + Type string `json:"type"` + URL string `json:"url"` } // messageEncoder is a function that knows how to encode a message diff --git a/server/server.go b/server/server.go index 9cf76dea..b8ca70f7 100644 --- a/server/server.go +++ b/server/server.go @@ -15,14 +15,18 @@ import ( "html/template" "io" "log" + "mime" "net" "net/http" "net/http/httptest" + "os" + "path/filepath" "regexp" "strconv" "strings" "sync" "time" + "unicode/utf8" ) // TODO add "max messages in a topic" limit @@ -96,7 +100,8 @@ var ( staticRegex = regexp.MustCompile(`^/static/.+`) docsRegex = regexp.MustCompile(`^/docs(|/.*)$`) - disallowedTopics = []string{"docs", "static"} + fileRegex = regexp.MustCompile(`^/file/([-_A-Za-z0-9]{1,64})(?:\.[A-Za-z0-9]{1,16})?$`) + disallowedTopics = []string{"docs", "static", "file"} templateFnMap = template.FuncMap{ "durationToHuman": util.DurationToHuman, @@ -117,22 +122,26 @@ var ( docsStaticFs embed.FS docsStaticCached = &util.CachingEmbedFS{ModTime: time.Now(), FS: docsStaticFs} - errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", ""} - errHTTPTooManyRequestsLimitRequests = &errHTTP{42901, http.StatusTooManyRequests, "limit reached: too many requests, please be nice", "https://ntfy.sh/docs/publish/#limitations"} - errHTTPTooManyRequestsLimitEmails = &errHTTP{42902, http.StatusTooManyRequests, "limit reached: too many emails, please be nice", "https://ntfy.sh/docs/publish/#limitations"} - errHTTPTooManyRequestsLimitSubscriptions = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"} - errHTTPTooManyRequestsLimitGlobalTopics = &errHTTP{42904, http.StatusTooManyRequests, "limit reached: the total number of topics on the server has been reached, please contact the admin", "https://ntfy.sh/docs/publish/#limitations"} - errHTTPBadRequestEmailDisabled = &errHTTP{40001, http.StatusBadRequest, "e-mail notifications are not enabled", "https://ntfy.sh/docs/config/#e-mail-notifications"} - errHTTPBadRequestDelayNoCache = &errHTTP{40002, http.StatusBadRequest, "cannot disable cache for delayed message", ""} - errHTTPBadRequestDelayNoEmail = &errHTTP{40003, http.StatusBadRequest, "delayed e-mail notifications are not supported", ""} - errHTTPBadRequestDelayCannotParse = &errHTTP{40004, http.StatusBadRequest, "invalid delay parameter: unable to parse delay", "https://ntfy.sh/docs/publish/#scheduled-delivery"} - errHTTPBadRequestDelayTooSmall = &errHTTP{40005, http.StatusBadRequest, "invalid delay parameter: too small, please refer to the docs", "https://ntfy.sh/docs/publish/#scheduled-delivery"} - errHTTPBadRequestDelayTooLarge = &errHTTP{40006, http.StatusBadRequest, "invalid delay parameter: too large, please refer to the docs", "https://ntfy.sh/docs/publish/#scheduled-delivery"} - errHTTPBadRequestPriorityInvalid = &errHTTP{40007, http.StatusBadRequest, "invalid priority parameter", "https://ntfy.sh/docs/publish/#message-priority"} - errHTTPBadRequestSinceInvalid = &errHTTP{40008, http.StatusBadRequest, "invalid since parameter", "https://ntfy.sh/docs/subscribe/api/#fetch-cached-messages"} - errHTTPBadRequestTopicInvalid = &errHTTP{40009, http.StatusBadRequest, "invalid topic: path invalid", ""} - errHTTPBadRequestTopicDisallowed = &errHTTP{40010, http.StatusBadRequest, "invalid topic: topic name is disallowed", ""} - errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""} + errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", ""} + errHTTPTooManyRequestsLimitRequests = &errHTTP{42901, http.StatusTooManyRequests, "limit reached: too many requests, please be nice", "https://ntfy.sh/docs/publish/#limitations"} + errHTTPTooManyRequestsLimitEmails = &errHTTP{42902, http.StatusTooManyRequests, "limit reached: too many emails, please be nice", "https://ntfy.sh/docs/publish/#limitations"} + errHTTPTooManyRequestsLimitSubscriptions = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"} + errHTTPTooManyRequestsLimitGlobalTopics = &errHTTP{42904, http.StatusTooManyRequests, "limit reached: the total number of topics on the server has been reached, please contact the admin", "https://ntfy.sh/docs/publish/#limitations"} + errHTTPBadRequestEmailDisabled = &errHTTP{40001, http.StatusBadRequest, "e-mail notifications are not enabled", "https://ntfy.sh/docs/config/#e-mail-notifications"} + errHTTPBadRequestDelayNoCache = &errHTTP{40002, http.StatusBadRequest, "cannot disable cache for delayed message", ""} + errHTTPBadRequestDelayNoEmail = &errHTTP{40003, http.StatusBadRequest, "delayed e-mail notifications are not supported", ""} + errHTTPBadRequestDelayCannotParse = &errHTTP{40004, http.StatusBadRequest, "invalid delay parameter: unable to parse delay", "https://ntfy.sh/docs/publish/#scheduled-delivery"} + errHTTPBadRequestDelayTooSmall = &errHTTP{40005, http.StatusBadRequest, "invalid delay parameter: too small, please refer to the docs", "https://ntfy.sh/docs/publish/#scheduled-delivery"} + errHTTPBadRequestDelayTooLarge = &errHTTP{40006, http.StatusBadRequest, "invalid delay parameter: too large, please refer to the docs", "https://ntfy.sh/docs/publish/#scheduled-delivery"} + errHTTPBadRequestPriorityInvalid = &errHTTP{40007, http.StatusBadRequest, "invalid priority parameter", "https://ntfy.sh/docs/publish/#message-priority"} + errHTTPBadRequestSinceInvalid = &errHTTP{40008, http.StatusBadRequest, "invalid since parameter", "https://ntfy.sh/docs/subscribe/api/#fetch-cached-messages"} + errHTTPBadRequestTopicInvalid = &errHTTP{40009, http.StatusBadRequest, "invalid topic: path invalid", ""} + errHTTPBadRequestTopicDisallowed = &errHTTP{40010, http.StatusBadRequest, "invalid topic: topic name is disallowed", ""} + errHTTPBadRequestAttachmentsDisallowed = &errHTTP{40011, http.StatusBadRequest, "attachments disallowed", ""} + errHTTPBadRequestAttachmentsPublishDisallowed = &errHTTP{40011, http.StatusBadRequest, "invalid message: invalid encoding or too large, and attachments are not allowed", ""} + errHTTPBadRequestMessageTooLarge = &errHTTP{40013, http.StatusBadRequest, "invalid message: too large", ""} + errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""} + errHTTPInternalErrorInvalidFilePath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid file path", ""} ) const ( @@ -163,6 +172,11 @@ func New(conf *Config) (*Server, error) { if err != nil { return nil, err } + if conf.AttachmentCacheDir != "" { + if err := os.MkdirAll(conf.AttachmentCacheDir, 0700); err != nil { + return nil, err + } + } return &Server{ config: conf, cache: cache, @@ -302,6 +316,8 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error { return s.handleStatic(w, r) } else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) { return s.handleDocs(w, r) + } else if r.Method == http.MethodGet && fileRegex.MatchString(r.URL.Path) { + return s.handleFile(w, r) } else if r.Method == http.MethodOptions { return s.handleOptions(w, r) } else if r.Method == http.MethodGet && topicPathRegex.MatchString(r.URL.Path) { @@ -357,17 +373,45 @@ func (s *Server) handleDocs(w http.ResponseWriter, r *http.Request) error { return nil } +func (s *Server) handleFile(w http.ResponseWriter, r *http.Request) error { + if s.config.AttachmentCacheDir == "" { + return errHTTPBadRequestAttachmentsDisallowed + } + matches := fileRegex.FindStringSubmatch(r.URL.Path) + if len(matches) != 2 { + return errHTTPInternalErrorInvalidFilePath + } + messageID := matches[1] + file := filepath.Join(s.config.AttachmentCacheDir, messageID) + stat, err := os.Stat(file) + if err != nil { + return errHTTPNotFound + } + w.Header().Set("Length", fmt.Sprintf("%d", stat.Size())) + f, err := os.Open(file) + if err != nil { + return err + } + defer f.Close() + _, err = io.Copy(util.NewContentTypeWriter(w), f) + return err +} + func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visitor) error { t, err := s.topicFromPath(r.URL.Path) if err != nil { return err } - reader := io.LimitReader(r.Body, int64(s.config.MessageLimit)) - b, err := io.ReadAll(reader) + body, err := util.Peak(r.Body, s.config.MessageLimit) if err != nil { return err } - m := newDefaultMessage(t.ID, strings.TrimSpace(string(b))) + m := newDefaultMessage(t.ID, "") + if !body.LimitReached && utf8.Valid(body.PeakedBytes) { + m.Message = strings.TrimSpace(string(body.PeakedBytes)) + } else if err := s.writeAttachment(v, m, body); err != nil { + return err + } cache, firebase, email, err := s.parsePublishParams(r, m) if err != nil { return err @@ -478,6 +522,48 @@ func readParam(r *http.Request, names ...string) string { return "" } +func (s *Server) writeAttachment(v *visitor, m *message, body *util.PeakedReadCloser) error { + if s.config.AttachmentCacheDir == "" || !util.FileExists(s.config.AttachmentCacheDir) { + return errHTTPBadRequestAttachmentsPublishDisallowed + } + contentType := http.DetectContentType(body.PeakedBytes) + exts, err := mime.ExtensionsByType(contentType) + if err != nil { + return err + } + ext := ".bin" + if len(exts) > 0 { + ext = exts[0] + } + filename := fmt.Sprintf("attachment%s", ext) + file := filepath.Join(s.config.AttachmentCacheDir, m.ID) + f, err := os.OpenFile(file, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600) + if err != nil { + return err + } + defer f.Close() + fileSizeLimiter := util.NewLimiter(s.config.AttachmentSizeLimit) + limitWriter := util.NewLimitWriter(f, fileSizeLimiter) + if _, err := io.Copy(limitWriter, body); err != nil { + os.Remove(file) + if err == util.ErrLimitReached { + return errHTTPBadRequestMessageTooLarge + } + return err + } + if err := f.Close(); err != nil { + os.Remove(file) + return err + } + m.Message = fmt.Sprintf("You received a file: %s", filename) + m.Attachment = &attachment{ + Name: filename, + Type: contentType, + URL: fmt.Sprintf("%s/file/%s%s", s.config.BaseURL, m.ID, ext), + } + return nil +} + func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v *visitor) error { encoder := func(msg *message) (string, error) { var buf bytes.Buffer @@ -691,7 +777,7 @@ func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) { return nil, errHTTPBadRequestTopicDisallowed } if _, ok := s.topics[id]; !ok { - if len(s.topics) >= s.config.GlobalTopicLimit { + if len(s.topics) >= s.config.TotalTopicLimit { return nil, errHTTPTooManyRequestsLimitGlobalTopics } s.topics[id] = newTopic(id) diff --git a/server/server_test.go b/server/server_test.go index e713e604..f8a1a8a2 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -165,17 +165,8 @@ func TestServer_PublishLargeMessage(t *testing.T) { s := newTestServer(t, newTestConfig(t)) body := strings.Repeat("this is a large message", 5000) - truncated := body[0:4096] response := request(t, s, "PUT", "/mytopic", body, nil) - msg := toMessage(t, response.Body.String()) - require.NotEmpty(t, msg.ID) - require.Equal(t, truncated, msg.Message) - require.Equal(t, 4096, len(msg.Message)) - - response = request(t, s, "GET", "/mytopic/json?poll=1", "", nil) - messages := toMessages(t, response.Body.String()) - require.Equal(t, 1, len(messages)) - require.Equal(t, truncated, messages[0].Message) + require.Equal(t, 400, response.Code) } func TestServer_PublishPriority(t *testing.T) { diff --git a/util/content_type_writer.go b/util/content_type_writer.go new file mode 100644 index 00000000..fb3c43f8 --- /dev/null +++ b/util/content_type_writer.go @@ -0,0 +1,41 @@ +package util + +import ( + "net/http" + "strings" +) + +// ContentTypeWriter is an implementation of http.ResponseWriter that will detect the content type and set the +// Content-Type and (optionally) Content-Disposition headers accordingly. +// +// It will always set a Content-Type based on http.DetectContentType, but will never send the "text/html" +// content type. +type ContentTypeWriter struct { + w http.ResponseWriter + sniffed bool +} + +// NewContentTypeWriter creates a new ContentTypeWriter +func NewContentTypeWriter(w http.ResponseWriter) *ContentTypeWriter { + return &ContentTypeWriter{w, false} +} + +func (w *ContentTypeWriter) Write(p []byte) (n int, err error) { + if w.sniffed { + return w.w.Write(p) + } + // Detect and set Content-Type header + // Fix content types that we don't want to inline-render in the browser. In particular, + // we don't want to render HTML in the browser for security reasons. + contentType := http.DetectContentType(p) + if strings.HasPrefix(contentType, "text/html") { + contentType = strings.ReplaceAll(contentType, "text/html", "text/plain") + } else if contentType == "application/octet-stream" { + contentType = "" // Reset to let downstream http.ResponseWriter take care of it + } + if contentType != "" { + w.w.Header().Set("Content-Type", contentType) + } + w.sniffed = true + return w.w.Write(p) +} diff --git a/util/content_type_writer_test.go b/util/content_type_writer_test.go new file mode 100644 index 00000000..08dd751b --- /dev/null +++ b/util/content_type_writer_test.go @@ -0,0 +1,50 @@ +package util + +import ( + "crypto/rand" + "github.com/stretchr/testify/require" + "net/http/httptest" + "testing" +) + +func TestSniffWriter_WriteHTML(t *testing.T) { + rr := httptest.NewRecorder() + sw := NewContentTypeWriter(rr) + sw.Write([]byte("")) + require.Equal(t, "text/plain; charset=utf-8", rr.Header().Get("Content-Type")) +} + +func TestSniffWriter_WriteTwoWriteCalls(t *testing.T) { + rr := httptest.NewRecorder() + sw := NewContentTypeWriter(rr) + sw.Write([]byte{0x25, 0x50, 0x44, 0x46, 0x2d, 0x11, 0x22, 0x33}) + sw.Write([]byte("")) + require.Equal(t, "application/pdf", rr.Header().Get("Content-Type")) +} + +func TestSniffWriter_NoSniffWriterWriteHTML(t *testing.T) { + // This test just makes sure that without the sniff-w, we would get text/html + + rr := httptest.NewRecorder() + rr.Write([]byte("")) + require.Equal(t, "text/html; charset=utf-8", rr.Header().Get("Content-Type")) +} + +func TestSniffWriter_WriteHTMLSplitIntoTwoWrites(t *testing.T) { + // This test shows how splitting the HTML into two Write() calls will still yield text/plain + + rr := httptest.NewRecorder() + sw := NewContentTypeWriter(rr) + sw.Write([]byte("alert('hi')")) + require.Equal(t, "text/plain; charset=utf-8", rr.Header().Get("Content-Type")) +} + +func TestSniffWriter_WriteUnknownMimeType(t *testing.T) { + rr := httptest.NewRecorder() + sw := NewContentTypeWriter(rr) + randomBytes := make([]byte, 199) + rand.Read(randomBytes) + sw.Write(randomBytes) + require.Equal(t, "application/octet-stream", rr.Header().Get("Content-Type")) +} diff --git a/util/limit.go b/util/limit.go index e5561247..bac3c155 100644 --- a/util/limit.go +++ b/util/limit.go @@ -2,6 +2,7 @@ package util import ( "errors" + "io" "sync" ) @@ -58,3 +59,43 @@ func (l *Limiter) Value() int64 { defer l.mu.Unlock() return l.value } + +// Limit returns the defined limit +func (l *Limiter) Limit() int64 { + return l.limit +} + +// LimitWriter implements an io.Writer that will pass through all Write calls to the underlying +// writer w until any of the limiter's limit is reached, at which point a Write will return ErrLimitReached. +// Each limiter's value is increased with every write. +type LimitWriter struct { + w io.Writer + written int64 + limiters []*Limiter + mu sync.Mutex +} + +// NewLimitWriter creates a new LimitWriter +func NewLimitWriter(w io.Writer, limiters ...*Limiter) *LimitWriter { + return &LimitWriter{ + w: w, + limiters: limiters, + } +} + +// Write passes through all writes to the underlying writer until any of the given limiter's limit is reached +func (w *LimitWriter) Write(p []byte) (n int, err error) { + w.mu.Lock() + defer w.mu.Unlock() + for i := 0; i < len(w.limiters); i++ { + if err := w.limiters[i].Add(int64(len(p))); err != nil { + for j := i - 1; j >= 0; j-- { + w.limiters[j].Sub(int64(len(p))) + } + return 0, ErrLimitReached + } + } + n, err = w.w.Write(p) + w.written += int64(n) + return +} diff --git a/util/limit_test.go b/util/limit_test.go index f6d56c6d..4f07e00f 100644 --- a/util/limit_test.go +++ b/util/limit_test.go @@ -1,6 +1,7 @@ package util import ( + "bytes" "testing" ) @@ -17,14 +18,68 @@ func TestLimiter_Add(t *testing.T) { } } -func TestLimiter_AddSub(t *testing.T) { +func TestLimiter_AddSet(t *testing.T) { l := NewLimiter(10) l.Add(5) if l.Value() != 5 { t.Fatalf("expected value to be %d, got %d", 5, l.Value()) } - l.Sub(2) - if l.Value() != 3 { - t.Fatalf("expected value to be %d, got %d", 3, l.Value()) + l.Set(7) + if l.Value() != 7 { + t.Fatalf("expected value to be %d, got %d", 7, l.Value()) + } +} + +func TestLimitWriter_WriteNoLimiter(t *testing.T) { + var buf bytes.Buffer + lw := NewLimitWriter(&buf) + if _, err := lw.Write(make([]byte, 10)); err != nil { + t.Fatal(err) + } + if _, err := lw.Write(make([]byte, 1)); err != nil { + t.Fatal(err) + } + if buf.Len() != 11 { + t.Fatalf("expected buffer length to be %d, got %d", 11, buf.Len()) + } +} + +func TestLimitWriter_WriteOneLimiter(t *testing.T) { + var buf bytes.Buffer + l := NewLimiter(10) + lw := NewLimitWriter(&buf, l) + if _, err := lw.Write(make([]byte, 10)); err != nil { + t.Fatal(err) + } + if _, err := lw.Write(make([]byte, 1)); err != ErrLimitReached { + t.Fatalf("expected ErrLimitReached, got %#v", err) + } + if buf.Len() != 10 { + t.Fatalf("expected buffer length to be %d, got %d", 10, buf.Len()) + } + if l.Value() != 10 { + t.Fatalf("expected limiter value to be %d, got %d", 10, l.Value()) + } +} + +func TestLimitWriter_WriteTwoLimiters(t *testing.T) { + var buf bytes.Buffer + l1 := NewLimiter(11) + l2 := NewLimiter(9) + lw := NewLimitWriter(&buf, l1, l2) + if _, err := lw.Write(make([]byte, 8)); err != nil { + t.Fatal(err) + } + if _, err := lw.Write(make([]byte, 2)); err != ErrLimitReached { + t.Fatalf("expected ErrLimitReached, got %#v", err) + } + if buf.Len() != 8 { + t.Fatalf("expected buffer length to be %d, got %d", 8, buf.Len()) + } + if l1.Value() != 8 { + t.Fatalf("expected limiter 1 value to be %d, got %d", 8, l1.Value()) + } + if l2.Value() != 8 { + t.Fatalf("expected limiter 2 value to be %d, got %d", 8, l2.Value()) } } diff --git a/util/peak.go b/util/peak.go new file mode 100644 index 00000000..100c269b --- /dev/null +++ b/util/peak.go @@ -0,0 +1,61 @@ +package util + +import ( + "bytes" + "io" + "strings" +) + +// PeakedReadCloser is a ReadCloser that allows peaking into a stream and buffering it in memory. +// It can be instantiated using the Peak function. After a stream has been peaked, it can still be fully +// read by reading the PeakedReadCloser. It first drained from the memory buffer, and then from the remaining +// underlying reader. +type PeakedReadCloser struct { + PeakedBytes []byte + LimitReached bool + peaked io.Reader + underlying io.ReadCloser + closed bool +} + +// Peak reads the underlying ReadCloser into memory up until the limit and returns a PeakedReadCloser +func Peak(underlying io.ReadCloser, limit int) (*PeakedReadCloser, error) { + if underlying == nil { + underlying = io.NopCloser(strings.NewReader("")) + } + peaked := make([]byte, limit) + read, err := io.ReadFull(underlying, peaked) + if err != nil && err != io.ErrUnexpectedEOF && err != io.EOF { + return nil, err + } + return &PeakedReadCloser{ + PeakedBytes: peaked[:read], + LimitReached: read == limit, + underlying: underlying, + peaked: bytes.NewReader(peaked[:read]), + closed: false, + }, nil +} + +// Read reads from the peaked bytes and then from the underlying stream +func (r *PeakedReadCloser) Read(p []byte) (n int, err error) { + if r.closed { + return 0, io.EOF + } + n, err = r.peaked.Read(p) + if err == io.EOF { + return r.underlying.Read(p) + } else if err != nil { + return 0, err + } + return +} + +// Close closes the underlying stream +func (r *PeakedReadCloser) Close() error { + if r.closed { + return io.EOF + } + r.closed = true + return r.underlying.Close() +} diff --git a/util/peak_test.go b/util/peak_test.go new file mode 100644 index 00000000..76995179 --- /dev/null +++ b/util/peak_test.go @@ -0,0 +1,55 @@ +package util + +import ( + "github.com/stretchr/testify/require" + "io" + "strings" + "testing" +) + +func TestPeak_LimitReached(t *testing.T) { + underlying := io.NopCloser(strings.NewReader("1234567890")) + peaked, err := Peak(underlying, 5) + if err != nil { + t.Fatal(err) + } + require.Equal(t, []byte("12345"), peaked.PeakedBytes) + require.Equal(t, true, peaked.LimitReached) + + all, err := io.ReadAll(peaked) + if err != nil { + t.Fatal(err) + } + require.Equal(t, []byte("1234567890"), all) + require.Equal(t, []byte("12345"), peaked.PeakedBytes) + require.Equal(t, true, peaked.LimitReached) +} + +func TestPeak_LimitNotReached(t *testing.T) { + underlying := io.NopCloser(strings.NewReader("1234567890")) + peaked, err := Peak(underlying, 15) + if err != nil { + t.Fatal(err) + } + all, err := io.ReadAll(peaked) + if err != nil { + t.Fatal(err) + } + require.Equal(t, []byte("1234567890"), all) + require.Equal(t, []byte("1234567890"), peaked.PeakedBytes) + require.Equal(t, false, peaked.LimitReached) +} + +func TestPeak_Nil(t *testing.T) { + peaked, err := Peak(nil, 15) + if err != nil { + t.Fatal(err) + } + all, err := io.ReadAll(peaked) + if err != nil { + t.Fatal(err) + } + require.Equal(t, []byte(""), all) + require.Equal(t, []byte(""), peaked.PeakedBytes) + require.Equal(t, false, peaked.LimitReached) +}