diff --git a/.gitignore b/.gitignore
index 6dffcf55..6d12c730 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,4 +3,5 @@ build/
.idea/
server/docs/
tools/fbsend/fbsend
+playground/
*.iml
diff --git a/cmd/serve.go b/cmd/serve.go
index 5545206f..f4161e1a 100644
--- a/cmd/serve.go
+++ b/cmd/serve.go
@@ -20,6 +20,7 @@ var flagsServe = []cli.Flag{
altsrc.NewStringFlag(&cli.StringFlag{Name: "firebase-key-file", Aliases: []string{"F"}, EnvVars: []string{"NTFY_FIREBASE_KEY_FILE"}, Usage: "Firebase credentials file; if set additionally publish to FCM topic"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-file", Aliases: []string{"C"}, EnvVars: []string{"NTFY_CACHE_FILE"}, Usage: "cache file used for message caching"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "cache-duration", Aliases: []string{"b"}, EnvVars: []string{"NTFY_CACHE_DURATION"}, Value: server.DefaultCacheDuration, Usage: "buffer messages for this time to allow `since` requests"}),
+ altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-cache-dir", EnvVars: []string{"NTFY_ATTACHMENT_CACHE_DIR"}, Usage: "cache directory for attached files"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "keepalive-interval", Aliases: []string{"k"}, EnvVars: []string{"NTFY_KEEPALIVE_INTERVAL"}, Value: server.DefaultKeepaliveInterval, Usage: "interval of keepalive messages"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "manager-interval", Aliases: []string{"m"}, EnvVars: []string{"NTFY_MANAGER_INTERVAL"}, Value: server.DefaultManagerInterval, Usage: "interval of for message pruning and stats printing"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-sender-addr", EnvVars: []string{"NTFY_SMTP_SENDER_ADDR"}, Usage: "SMTP server address (host:port) for outgoing emails"}),
@@ -69,6 +70,7 @@ func execServe(c *cli.Context) error {
firebaseKeyFile := c.String("firebase-key-file")
cacheFile := c.String("cache-file")
cacheDuration := c.Duration("cache-duration")
+ attachmentCacheDir := c.String("attachment-cache-dir")
keepaliveInterval := c.Duration("keepalive-interval")
managerInterval := c.Duration("manager-interval")
smtpSenderAddr := c.String("smtp-sender-addr")
@@ -117,6 +119,7 @@ func execServe(c *cli.Context) error {
conf.FirebaseKeyFile = firebaseKeyFile
conf.CacheFile = cacheFile
conf.CacheDuration = cacheDuration
+ conf.AttachmentCacheDir = attachmentCacheDir
conf.KeepaliveInterval = keepaliveInterval
conf.ManagerInterval = managerInterval
conf.SMTPSenderAddr = smtpSenderAddr
@@ -126,7 +129,7 @@ func execServe(c *cli.Context) error {
conf.SMTPServerListen = smtpServerListen
conf.SMTPServerDomain = smtpServerDomain
conf.SMTPServerAddrPrefix = smtpServerAddrPrefix
- conf.GlobalTopicLimit = globalTopicLimit
+ conf.TotalTopicLimit = globalTopicLimit
conf.VisitorSubscriptionLimit = visitorSubscriptionLimit
conf.VisitorRequestLimitBurst = visitorRequestLimitBurst
conf.VisitorRequestLimitReplenish = visitorRequestLimitReplenish
diff --git a/docs/publish.md b/docs/publish.md
index ec017e05..46a5a334 100644
--- a/docs/publish.md
+++ b/docs/publish.md
@@ -886,3 +886,4 @@ and can be passed as **HTTP headers** or **query parameters in the URL**. They a
| `X-Email` | `X-E-Mail`, `Email`, `E-Mail`, `mail`, `e` | E-mail address for [e-mail notifications](#e-mail-notifications) |
| `X-Cache` | `Cache` | Allows disabling [message caching](#message-caching) |
| `X-Firebase` | `Firebase` | Allows disabling [sending to Firebase](#disable-firebase) |
+| `X-UnifiedPush` | `UnifiedPush`, `up` | XXXXXXXXXXXXXXXX |
diff --git a/server/config.go b/server/config.go
index 30f937e9..68a911fb 100644
--- a/server/config.go
+++ b/server/config.go
@@ -14,6 +14,7 @@ const (
DefaultMinDelay = 10 * time.Second
DefaultMaxDelay = 3 * 24 * time.Hour
DefaultMessageLimit = 4096
+ DefaultAttachmentSizeLimit = 5 * 1024 * 1024
DefaultFirebaseKeepaliveInterval = 3 * time.Hour // Not too frequently to save battery
)
@@ -41,6 +42,8 @@ type Config struct {
FirebaseKeyFile string
CacheFile string
CacheDuration time.Duration
+ AttachmentCacheDir string
+ AttachmentSizeLimit int64
KeepaliveInterval time.Duration
ManagerInterval time.Duration
AtSenderInterval time.Duration
@@ -55,7 +58,8 @@ type Config struct {
MessageLimit int
MinDelay time.Duration
MaxDelay time.Duration
- GlobalTopicLimit int
+ TotalTopicLimit int
+ TotalAttachmentSizeLimit int64
VisitorRequestLimitBurst int
VisitorRequestLimitReplenish time.Duration
VisitorEmailLimitBurst int
@@ -75,6 +79,8 @@ func NewConfig() *Config {
FirebaseKeyFile: "",
CacheFile: "",
CacheDuration: DefaultCacheDuration,
+ AttachmentCacheDir: "",
+ AttachmentSizeLimit: DefaultAttachmentSizeLimit,
KeepaliveInterval: DefaultKeepaliveInterval,
ManagerInterval: DefaultManagerInterval,
MessageLimit: DefaultMessageLimit,
@@ -82,7 +88,7 @@ func NewConfig() *Config {
MaxDelay: DefaultMaxDelay,
AtSenderInterval: DefaultAtSenderInterval,
FirebaseKeepaliveInterval: DefaultFirebaseKeepaliveInterval,
- GlobalTopicLimit: DefaultGlobalTopicLimit,
+ TotalTopicLimit: DefaultGlobalTopicLimit,
VisitorRequestLimitBurst: DefaultVisitorRequestLimitBurst,
VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish,
VisitorEmailLimitBurst: DefaultVisitorEmailLimitBurst,
diff --git a/server/message.go b/server/message.go
index ad870e09..2c3fb198 100644
--- a/server/message.go
+++ b/server/message.go
@@ -18,14 +18,21 @@ const (
// message represents a message published to a topic
type message struct {
- ID string `json:"id"` // Random message ID
- Time int64 `json:"time"` // Unix time in seconds
- Event string `json:"event"` // One of the above
- Topic string `json:"topic"`
- Priority int `json:"priority,omitempty"`
- Tags []string `json:"tags,omitempty"`
- Title string `json:"title,omitempty"`
- Message string `json:"message,omitempty"`
+ ID string `json:"id"` // Random message ID
+ Time int64 `json:"time"` // Unix time in seconds
+ Event string `json:"event"` // One of the above
+ Topic string `json:"topic"`
+ Priority int `json:"priority,omitempty"`
+ Tags []string `json:"tags,omitempty"`
+ Title string `json:"title,omitempty"`
+ Message string `json:"message,omitempty"`
+ Attachment *attachment `json:"attachment,omitempty"`
+}
+
+type attachment struct {
+ Name string `json:"name"`
+ Type string `json:"type"`
+ URL string `json:"url"`
}
// messageEncoder is a function that knows how to encode a message
diff --git a/server/server.go b/server/server.go
index 9cf76dea..b8ca70f7 100644
--- a/server/server.go
+++ b/server/server.go
@@ -15,14 +15,18 @@ import (
"html/template"
"io"
"log"
+ "mime"
"net"
"net/http"
"net/http/httptest"
+ "os"
+ "path/filepath"
"regexp"
"strconv"
"strings"
"sync"
"time"
+ "unicode/utf8"
)
// TODO add "max messages in a topic" limit
@@ -96,7 +100,8 @@ var (
staticRegex = regexp.MustCompile(`^/static/.+`)
docsRegex = regexp.MustCompile(`^/docs(|/.*)$`)
- disallowedTopics = []string{"docs", "static"}
+ fileRegex = regexp.MustCompile(`^/file/([-_A-Za-z0-9]{1,64})(?:\.[A-Za-z0-9]{1,16})?$`)
+ disallowedTopics = []string{"docs", "static", "file"}
templateFnMap = template.FuncMap{
"durationToHuman": util.DurationToHuman,
@@ -117,22 +122,26 @@ var (
docsStaticFs embed.FS
docsStaticCached = &util.CachingEmbedFS{ModTime: time.Now(), FS: docsStaticFs}
- errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", ""}
- errHTTPTooManyRequestsLimitRequests = &errHTTP{42901, http.StatusTooManyRequests, "limit reached: too many requests, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
- errHTTPTooManyRequestsLimitEmails = &errHTTP{42902, http.StatusTooManyRequests, "limit reached: too many emails, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
- errHTTPTooManyRequestsLimitSubscriptions = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
- errHTTPTooManyRequestsLimitGlobalTopics = &errHTTP{42904, http.StatusTooManyRequests, "limit reached: the total number of topics on the server has been reached, please contact the admin", "https://ntfy.sh/docs/publish/#limitations"}
- errHTTPBadRequestEmailDisabled = &errHTTP{40001, http.StatusBadRequest, "e-mail notifications are not enabled", "https://ntfy.sh/docs/config/#e-mail-notifications"}
- errHTTPBadRequestDelayNoCache = &errHTTP{40002, http.StatusBadRequest, "cannot disable cache for delayed message", ""}
- errHTTPBadRequestDelayNoEmail = &errHTTP{40003, http.StatusBadRequest, "delayed e-mail notifications are not supported", ""}
- errHTTPBadRequestDelayCannotParse = &errHTTP{40004, http.StatusBadRequest, "invalid delay parameter: unable to parse delay", "https://ntfy.sh/docs/publish/#scheduled-delivery"}
- errHTTPBadRequestDelayTooSmall = &errHTTP{40005, http.StatusBadRequest, "invalid delay parameter: too small, please refer to the docs", "https://ntfy.sh/docs/publish/#scheduled-delivery"}
- errHTTPBadRequestDelayTooLarge = &errHTTP{40006, http.StatusBadRequest, "invalid delay parameter: too large, please refer to the docs", "https://ntfy.sh/docs/publish/#scheduled-delivery"}
- errHTTPBadRequestPriorityInvalid = &errHTTP{40007, http.StatusBadRequest, "invalid priority parameter", "https://ntfy.sh/docs/publish/#message-priority"}
- errHTTPBadRequestSinceInvalid = &errHTTP{40008, http.StatusBadRequest, "invalid since parameter", "https://ntfy.sh/docs/subscribe/api/#fetch-cached-messages"}
- errHTTPBadRequestTopicInvalid = &errHTTP{40009, http.StatusBadRequest, "invalid topic: path invalid", ""}
- errHTTPBadRequestTopicDisallowed = &errHTTP{40010, http.StatusBadRequest, "invalid topic: topic name is disallowed", ""}
- errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""}
+ errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", ""}
+ errHTTPTooManyRequestsLimitRequests = &errHTTP{42901, http.StatusTooManyRequests, "limit reached: too many requests, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
+ errHTTPTooManyRequestsLimitEmails = &errHTTP{42902, http.StatusTooManyRequests, "limit reached: too many emails, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
+ errHTTPTooManyRequestsLimitSubscriptions = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
+ errHTTPTooManyRequestsLimitGlobalTopics = &errHTTP{42904, http.StatusTooManyRequests, "limit reached: the total number of topics on the server has been reached, please contact the admin", "https://ntfy.sh/docs/publish/#limitations"}
+ errHTTPBadRequestEmailDisabled = &errHTTP{40001, http.StatusBadRequest, "e-mail notifications are not enabled", "https://ntfy.sh/docs/config/#e-mail-notifications"}
+ errHTTPBadRequestDelayNoCache = &errHTTP{40002, http.StatusBadRequest, "cannot disable cache for delayed message", ""}
+ errHTTPBadRequestDelayNoEmail = &errHTTP{40003, http.StatusBadRequest, "delayed e-mail notifications are not supported", ""}
+ errHTTPBadRequestDelayCannotParse = &errHTTP{40004, http.StatusBadRequest, "invalid delay parameter: unable to parse delay", "https://ntfy.sh/docs/publish/#scheduled-delivery"}
+ errHTTPBadRequestDelayTooSmall = &errHTTP{40005, http.StatusBadRequest, "invalid delay parameter: too small, please refer to the docs", "https://ntfy.sh/docs/publish/#scheduled-delivery"}
+ errHTTPBadRequestDelayTooLarge = &errHTTP{40006, http.StatusBadRequest, "invalid delay parameter: too large, please refer to the docs", "https://ntfy.sh/docs/publish/#scheduled-delivery"}
+ errHTTPBadRequestPriorityInvalid = &errHTTP{40007, http.StatusBadRequest, "invalid priority parameter", "https://ntfy.sh/docs/publish/#message-priority"}
+ errHTTPBadRequestSinceInvalid = &errHTTP{40008, http.StatusBadRequest, "invalid since parameter", "https://ntfy.sh/docs/subscribe/api/#fetch-cached-messages"}
+ errHTTPBadRequestTopicInvalid = &errHTTP{40009, http.StatusBadRequest, "invalid topic: path invalid", ""}
+ errHTTPBadRequestTopicDisallowed = &errHTTP{40010, http.StatusBadRequest, "invalid topic: topic name is disallowed", ""}
+ errHTTPBadRequestAttachmentsDisallowed = &errHTTP{40011, http.StatusBadRequest, "attachments disallowed", ""}
+ errHTTPBadRequestAttachmentsPublishDisallowed = &errHTTP{40011, http.StatusBadRequest, "invalid message: invalid encoding or too large, and attachments are not allowed", ""}
+ errHTTPBadRequestMessageTooLarge = &errHTTP{40013, http.StatusBadRequest, "invalid message: too large", ""}
+ errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""}
+ errHTTPInternalErrorInvalidFilePath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid file path", ""}
)
const (
@@ -163,6 +172,11 @@ func New(conf *Config) (*Server, error) {
if err != nil {
return nil, err
}
+ if conf.AttachmentCacheDir != "" {
+ if err := os.MkdirAll(conf.AttachmentCacheDir, 0700); err != nil {
+ return nil, err
+ }
+ }
return &Server{
config: conf,
cache: cache,
@@ -302,6 +316,8 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
return s.handleStatic(w, r)
} else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) {
return s.handleDocs(w, r)
+ } else if r.Method == http.MethodGet && fileRegex.MatchString(r.URL.Path) {
+ return s.handleFile(w, r)
} else if r.Method == http.MethodOptions {
return s.handleOptions(w, r)
} else if r.Method == http.MethodGet && topicPathRegex.MatchString(r.URL.Path) {
@@ -357,17 +373,45 @@ func (s *Server) handleDocs(w http.ResponseWriter, r *http.Request) error {
return nil
}
+func (s *Server) handleFile(w http.ResponseWriter, r *http.Request) error {
+ if s.config.AttachmentCacheDir == "" {
+ return errHTTPBadRequestAttachmentsDisallowed
+ }
+ matches := fileRegex.FindStringSubmatch(r.URL.Path)
+ if len(matches) != 2 {
+ return errHTTPInternalErrorInvalidFilePath
+ }
+ messageID := matches[1]
+ file := filepath.Join(s.config.AttachmentCacheDir, messageID)
+ stat, err := os.Stat(file)
+ if err != nil {
+ return errHTTPNotFound
+ }
+ w.Header().Set("Length", fmt.Sprintf("%d", stat.Size()))
+ f, err := os.Open(file)
+ if err != nil {
+ return err
+ }
+ defer f.Close()
+ _, err = io.Copy(util.NewContentTypeWriter(w), f)
+ return err
+}
+
func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visitor) error {
t, err := s.topicFromPath(r.URL.Path)
if err != nil {
return err
}
- reader := io.LimitReader(r.Body, int64(s.config.MessageLimit))
- b, err := io.ReadAll(reader)
+ body, err := util.Peak(r.Body, s.config.MessageLimit)
if err != nil {
return err
}
- m := newDefaultMessage(t.ID, strings.TrimSpace(string(b)))
+ m := newDefaultMessage(t.ID, "")
+ if !body.LimitReached && utf8.Valid(body.PeakedBytes) {
+ m.Message = strings.TrimSpace(string(body.PeakedBytes))
+ } else if err := s.writeAttachment(v, m, body); err != nil {
+ return err
+ }
cache, firebase, email, err := s.parsePublishParams(r, m)
if err != nil {
return err
@@ -478,6 +522,48 @@ func readParam(r *http.Request, names ...string) string {
return ""
}
+func (s *Server) writeAttachment(v *visitor, m *message, body *util.PeakedReadCloser) error {
+ if s.config.AttachmentCacheDir == "" || !util.FileExists(s.config.AttachmentCacheDir) {
+ return errHTTPBadRequestAttachmentsPublishDisallowed
+ }
+ contentType := http.DetectContentType(body.PeakedBytes)
+ exts, err := mime.ExtensionsByType(contentType)
+ if err != nil {
+ return err
+ }
+ ext := ".bin"
+ if len(exts) > 0 {
+ ext = exts[0]
+ }
+ filename := fmt.Sprintf("attachment%s", ext)
+ file := filepath.Join(s.config.AttachmentCacheDir, m.ID)
+ f, err := os.OpenFile(file, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
+ if err != nil {
+ return err
+ }
+ defer f.Close()
+ fileSizeLimiter := util.NewLimiter(s.config.AttachmentSizeLimit)
+ limitWriter := util.NewLimitWriter(f, fileSizeLimiter)
+ if _, err := io.Copy(limitWriter, body); err != nil {
+ os.Remove(file)
+ if err == util.ErrLimitReached {
+ return errHTTPBadRequestMessageTooLarge
+ }
+ return err
+ }
+ if err := f.Close(); err != nil {
+ os.Remove(file)
+ return err
+ }
+ m.Message = fmt.Sprintf("You received a file: %s", filename)
+ m.Attachment = &attachment{
+ Name: filename,
+ Type: contentType,
+ URL: fmt.Sprintf("%s/file/%s%s", s.config.BaseURL, m.ID, ext),
+ }
+ return nil
+}
+
func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v *visitor) error {
encoder := func(msg *message) (string, error) {
var buf bytes.Buffer
@@ -691,7 +777,7 @@ func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) {
return nil, errHTTPBadRequestTopicDisallowed
}
if _, ok := s.topics[id]; !ok {
- if len(s.topics) >= s.config.GlobalTopicLimit {
+ if len(s.topics) >= s.config.TotalTopicLimit {
return nil, errHTTPTooManyRequestsLimitGlobalTopics
}
s.topics[id] = newTopic(id)
diff --git a/server/server_test.go b/server/server_test.go
index e713e604..f8a1a8a2 100644
--- a/server/server_test.go
+++ b/server/server_test.go
@@ -165,17 +165,8 @@ func TestServer_PublishLargeMessage(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
body := strings.Repeat("this is a large message", 5000)
- truncated := body[0:4096]
response := request(t, s, "PUT", "/mytopic", body, nil)
- msg := toMessage(t, response.Body.String())
- require.NotEmpty(t, msg.ID)
- require.Equal(t, truncated, msg.Message)
- require.Equal(t, 4096, len(msg.Message))
-
- response = request(t, s, "GET", "/mytopic/json?poll=1", "", nil)
- messages := toMessages(t, response.Body.String())
- require.Equal(t, 1, len(messages))
- require.Equal(t, truncated, messages[0].Message)
+ require.Equal(t, 400, response.Code)
}
func TestServer_PublishPriority(t *testing.T) {
diff --git a/util/content_type_writer.go b/util/content_type_writer.go
new file mode 100644
index 00000000..fb3c43f8
--- /dev/null
+++ b/util/content_type_writer.go
@@ -0,0 +1,41 @@
+package util
+
+import (
+ "net/http"
+ "strings"
+)
+
+// ContentTypeWriter is an implementation of http.ResponseWriter that will detect the content type and set the
+// Content-Type and (optionally) Content-Disposition headers accordingly.
+//
+// It will always set a Content-Type based on http.DetectContentType, but will never send the "text/html"
+// content type.
+type ContentTypeWriter struct {
+ w http.ResponseWriter
+ sniffed bool
+}
+
+// NewContentTypeWriter creates a new ContentTypeWriter
+func NewContentTypeWriter(w http.ResponseWriter) *ContentTypeWriter {
+ return &ContentTypeWriter{w, false}
+}
+
+func (w *ContentTypeWriter) Write(p []byte) (n int, err error) {
+ if w.sniffed {
+ return w.w.Write(p)
+ }
+ // Detect and set Content-Type header
+ // Fix content types that we don't want to inline-render in the browser. In particular,
+ // we don't want to render HTML in the browser for security reasons.
+ contentType := http.DetectContentType(p)
+ if strings.HasPrefix(contentType, "text/html") {
+ contentType = strings.ReplaceAll(contentType, "text/html", "text/plain")
+ } else if contentType == "application/octet-stream" {
+ contentType = "" // Reset to let downstream http.ResponseWriter take care of it
+ }
+ if contentType != "" {
+ w.w.Header().Set("Content-Type", contentType)
+ }
+ w.sniffed = true
+ return w.w.Write(p)
+}
diff --git a/util/content_type_writer_test.go b/util/content_type_writer_test.go
new file mode 100644
index 00000000..08dd751b
--- /dev/null
+++ b/util/content_type_writer_test.go
@@ -0,0 +1,50 @@
+package util
+
+import (
+ "crypto/rand"
+ "github.com/stretchr/testify/require"
+ "net/http/httptest"
+ "testing"
+)
+
+func TestSniffWriter_WriteHTML(t *testing.T) {
+ rr := httptest.NewRecorder()
+ sw := NewContentTypeWriter(rr)
+ sw.Write([]byte(""))
+ require.Equal(t, "text/plain; charset=utf-8", rr.Header().Get("Content-Type"))
+}
+
+func TestSniffWriter_WriteTwoWriteCalls(t *testing.T) {
+ rr := httptest.NewRecorder()
+ sw := NewContentTypeWriter(rr)
+ sw.Write([]byte{0x25, 0x50, 0x44, 0x46, 0x2d, 0x11, 0x22, 0x33})
+ sw.Write([]byte(""))
+ require.Equal(t, "application/pdf", rr.Header().Get("Content-Type"))
+}
+
+func TestSniffWriter_NoSniffWriterWriteHTML(t *testing.T) {
+ // This test just makes sure that without the sniff-w, we would get text/html
+
+ rr := httptest.NewRecorder()
+ rr.Write([]byte(""))
+ require.Equal(t, "text/html; charset=utf-8", rr.Header().Get("Content-Type"))
+}
+
+func TestSniffWriter_WriteHTMLSplitIntoTwoWrites(t *testing.T) {
+ // This test shows how splitting the HTML into two Write() calls will still yield text/plain
+
+ rr := httptest.NewRecorder()
+ sw := NewContentTypeWriter(rr)
+ sw.Write([]byte("alert('hi')"))
+ require.Equal(t, "text/plain; charset=utf-8", rr.Header().Get("Content-Type"))
+}
+
+func TestSniffWriter_WriteUnknownMimeType(t *testing.T) {
+ rr := httptest.NewRecorder()
+ sw := NewContentTypeWriter(rr)
+ randomBytes := make([]byte, 199)
+ rand.Read(randomBytes)
+ sw.Write(randomBytes)
+ require.Equal(t, "application/octet-stream", rr.Header().Get("Content-Type"))
+}
diff --git a/util/limit.go b/util/limit.go
index e5561247..bac3c155 100644
--- a/util/limit.go
+++ b/util/limit.go
@@ -2,6 +2,7 @@ package util
import (
"errors"
+ "io"
"sync"
)
@@ -58,3 +59,43 @@ func (l *Limiter) Value() int64 {
defer l.mu.Unlock()
return l.value
}
+
+// Limit returns the defined limit
+func (l *Limiter) Limit() int64 {
+ return l.limit
+}
+
+// LimitWriter implements an io.Writer that will pass through all Write calls to the underlying
+// writer w until any of the limiter's limit is reached, at which point a Write will return ErrLimitReached.
+// Each limiter's value is increased with every write.
+type LimitWriter struct {
+ w io.Writer
+ written int64
+ limiters []*Limiter
+ mu sync.Mutex
+}
+
+// NewLimitWriter creates a new LimitWriter
+func NewLimitWriter(w io.Writer, limiters ...*Limiter) *LimitWriter {
+ return &LimitWriter{
+ w: w,
+ limiters: limiters,
+ }
+}
+
+// Write passes through all writes to the underlying writer until any of the given limiter's limit is reached
+func (w *LimitWriter) Write(p []byte) (n int, err error) {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ for i := 0; i < len(w.limiters); i++ {
+ if err := w.limiters[i].Add(int64(len(p))); err != nil {
+ for j := i - 1; j >= 0; j-- {
+ w.limiters[j].Sub(int64(len(p)))
+ }
+ return 0, ErrLimitReached
+ }
+ }
+ n, err = w.w.Write(p)
+ w.written += int64(n)
+ return
+}
diff --git a/util/limit_test.go b/util/limit_test.go
index f6d56c6d..4f07e00f 100644
--- a/util/limit_test.go
+++ b/util/limit_test.go
@@ -1,6 +1,7 @@
package util
import (
+ "bytes"
"testing"
)
@@ -17,14 +18,68 @@ func TestLimiter_Add(t *testing.T) {
}
}
-func TestLimiter_AddSub(t *testing.T) {
+func TestLimiter_AddSet(t *testing.T) {
l := NewLimiter(10)
l.Add(5)
if l.Value() != 5 {
t.Fatalf("expected value to be %d, got %d", 5, l.Value())
}
- l.Sub(2)
- if l.Value() != 3 {
- t.Fatalf("expected value to be %d, got %d", 3, l.Value())
+ l.Set(7)
+ if l.Value() != 7 {
+ t.Fatalf("expected value to be %d, got %d", 7, l.Value())
+ }
+}
+
+func TestLimitWriter_WriteNoLimiter(t *testing.T) {
+ var buf bytes.Buffer
+ lw := NewLimitWriter(&buf)
+ if _, err := lw.Write(make([]byte, 10)); err != nil {
+ t.Fatal(err)
+ }
+ if _, err := lw.Write(make([]byte, 1)); err != nil {
+ t.Fatal(err)
+ }
+ if buf.Len() != 11 {
+ t.Fatalf("expected buffer length to be %d, got %d", 11, buf.Len())
+ }
+}
+
+func TestLimitWriter_WriteOneLimiter(t *testing.T) {
+ var buf bytes.Buffer
+ l := NewLimiter(10)
+ lw := NewLimitWriter(&buf, l)
+ if _, err := lw.Write(make([]byte, 10)); err != nil {
+ t.Fatal(err)
+ }
+ if _, err := lw.Write(make([]byte, 1)); err != ErrLimitReached {
+ t.Fatalf("expected ErrLimitReached, got %#v", err)
+ }
+ if buf.Len() != 10 {
+ t.Fatalf("expected buffer length to be %d, got %d", 10, buf.Len())
+ }
+ if l.Value() != 10 {
+ t.Fatalf("expected limiter value to be %d, got %d", 10, l.Value())
+ }
+}
+
+func TestLimitWriter_WriteTwoLimiters(t *testing.T) {
+ var buf bytes.Buffer
+ l1 := NewLimiter(11)
+ l2 := NewLimiter(9)
+ lw := NewLimitWriter(&buf, l1, l2)
+ if _, err := lw.Write(make([]byte, 8)); err != nil {
+ t.Fatal(err)
+ }
+ if _, err := lw.Write(make([]byte, 2)); err != ErrLimitReached {
+ t.Fatalf("expected ErrLimitReached, got %#v", err)
+ }
+ if buf.Len() != 8 {
+ t.Fatalf("expected buffer length to be %d, got %d", 8, buf.Len())
+ }
+ if l1.Value() != 8 {
+ t.Fatalf("expected limiter 1 value to be %d, got %d", 8, l1.Value())
+ }
+ if l2.Value() != 8 {
+ t.Fatalf("expected limiter 2 value to be %d, got %d", 8, l2.Value())
}
}
diff --git a/util/peak.go b/util/peak.go
new file mode 100644
index 00000000..100c269b
--- /dev/null
+++ b/util/peak.go
@@ -0,0 +1,61 @@
+package util
+
+import (
+ "bytes"
+ "io"
+ "strings"
+)
+
+// PeakedReadCloser is a ReadCloser that allows peaking into a stream and buffering it in memory.
+// It can be instantiated using the Peak function. After a stream has been peaked, it can still be fully
+// read by reading the PeakedReadCloser. It first drained from the memory buffer, and then from the remaining
+// underlying reader.
+type PeakedReadCloser struct {
+ PeakedBytes []byte
+ LimitReached bool
+ peaked io.Reader
+ underlying io.ReadCloser
+ closed bool
+}
+
+// Peak reads the underlying ReadCloser into memory up until the limit and returns a PeakedReadCloser
+func Peak(underlying io.ReadCloser, limit int) (*PeakedReadCloser, error) {
+ if underlying == nil {
+ underlying = io.NopCloser(strings.NewReader(""))
+ }
+ peaked := make([]byte, limit)
+ read, err := io.ReadFull(underlying, peaked)
+ if err != nil && err != io.ErrUnexpectedEOF && err != io.EOF {
+ return nil, err
+ }
+ return &PeakedReadCloser{
+ PeakedBytes: peaked[:read],
+ LimitReached: read == limit,
+ underlying: underlying,
+ peaked: bytes.NewReader(peaked[:read]),
+ closed: false,
+ }, nil
+}
+
+// Read reads from the peaked bytes and then from the underlying stream
+func (r *PeakedReadCloser) Read(p []byte) (n int, err error) {
+ if r.closed {
+ return 0, io.EOF
+ }
+ n, err = r.peaked.Read(p)
+ if err == io.EOF {
+ return r.underlying.Read(p)
+ } else if err != nil {
+ return 0, err
+ }
+ return
+}
+
+// Close closes the underlying stream
+func (r *PeakedReadCloser) Close() error {
+ if r.closed {
+ return io.EOF
+ }
+ r.closed = true
+ return r.underlying.Close()
+}
diff --git a/util/peak_test.go b/util/peak_test.go
new file mode 100644
index 00000000..76995179
--- /dev/null
+++ b/util/peak_test.go
@@ -0,0 +1,55 @@
+package util
+
+import (
+ "github.com/stretchr/testify/require"
+ "io"
+ "strings"
+ "testing"
+)
+
+func TestPeak_LimitReached(t *testing.T) {
+ underlying := io.NopCloser(strings.NewReader("1234567890"))
+ peaked, err := Peak(underlying, 5)
+ if err != nil {
+ t.Fatal(err)
+ }
+ require.Equal(t, []byte("12345"), peaked.PeakedBytes)
+ require.Equal(t, true, peaked.LimitReached)
+
+ all, err := io.ReadAll(peaked)
+ if err != nil {
+ t.Fatal(err)
+ }
+ require.Equal(t, []byte("1234567890"), all)
+ require.Equal(t, []byte("12345"), peaked.PeakedBytes)
+ require.Equal(t, true, peaked.LimitReached)
+}
+
+func TestPeak_LimitNotReached(t *testing.T) {
+ underlying := io.NopCloser(strings.NewReader("1234567890"))
+ peaked, err := Peak(underlying, 15)
+ if err != nil {
+ t.Fatal(err)
+ }
+ all, err := io.ReadAll(peaked)
+ if err != nil {
+ t.Fatal(err)
+ }
+ require.Equal(t, []byte("1234567890"), all)
+ require.Equal(t, []byte("1234567890"), peaked.PeakedBytes)
+ require.Equal(t, false, peaked.LimitReached)
+}
+
+func TestPeak_Nil(t *testing.T) {
+ peaked, err := Peak(nil, 15)
+ if err != nil {
+ t.Fatal(err)
+ }
+ all, err := io.ReadAll(peaked)
+ if err != nil {
+ t.Fatal(err)
+ }
+ require.Equal(t, []byte(""), all)
+ require.Equal(t, []byte(""), peaked.PeakedBytes)
+ require.Equal(t, false, peaked.LimitReached)
+}