From aba7e86cbc8786bfe3daa5288ec30e4bf78137d9 Mon Sep 17 00:00:00 2001 From: Philipp Heckel Date: Sun, 3 Apr 2022 12:39:52 -0400 Subject: [PATCH] Attachment behavior fix for Firefox --- server/message_cache.go | 2 +- server/message_cache_test.go | 4 +- server/server.go | 56 ++++++++++++-------- server/server_test.go | 6 +-- server/visitor.go | 28 +++++++++- util/limit.go | 16 ++++++ util/peak.go | 61 ---------------------- util/peek.go | 61 ++++++++++++++++++++++ util/{peak_test.go => peek_test.go} | 14 ++--- web/src/app/Api.js | 12 ++++- web/src/app/utils.js | 1 + web/src/components/App.js | 6 +-- web/src/components/SendDialog.js | 79 +++++++++++++++++++++-------- 13 files changed, 223 insertions(+), 123 deletions(-) delete mode 100644 util/peak.go create mode 100644 util/peek.go rename util/{peak_test.go => peek_test.go} (74%) diff --git a/server/message_cache.go b/server/message_cache.go index 4a48ac1a..cd503068 100644 --- a/server/message_cache.go +++ b/server/message_cache.go @@ -355,7 +355,7 @@ func (c *messageCache) Prune(olderThan time.Time) error { return err } -func (c *messageCache) AttachmentsSize(owner string) (int64, error) { +func (c *messageCache) AttachmentBytesUsed(owner string) (int64, error) { rows, err := c.db.Query(selectAttachmentsSizeQuery, owner, time.Now().Unix()) if err != nil { return 0, err diff --git a/server/message_cache_test.go b/server/message_cache_test.go index aea71c73..cb888b42 100644 --- a/server/message_cache_test.go +++ b/server/message_cache_test.go @@ -337,11 +337,11 @@ func testCacheAttachments(t *testing.T, c *messageCache) { require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL) require.Equal(t, "1.2.3.4", messages[1].Attachment.Owner) - size, err := c.AttachmentsSize("1.2.3.4") + size, err := c.AttachmentBytesUsed("1.2.3.4") require.Nil(t, err) require.Equal(t, int64(30000), size) - size, err = c.AttachmentsSize("5.6.7.8") + size, err = c.AttachmentBytesUsed("5.6.7.8") require.Nil(t, err) require.Equal(t, int64(0), size) diff --git a/server/server.go b/server/server.go index b041f697..c492eea9 100644 --- a/server/server.go +++ b/server/server.go @@ -66,6 +66,7 @@ var ( publishPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}/(publish|send|trigger)$`) webConfigPath = "/config.js" + userStatsPath = "/user/stats" staticRegex = regexp.MustCompile(`^/static/.+`) docsRegex = regexp.MustCompile(`^/docs(|/.*)$`) fileRegex = regexp.MustCompile(`^/file/([-_A-Za-z0-9]{1,64})(?:\.[A-Za-z0-9]{1,16})?$`) @@ -269,6 +270,8 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit return s.handleEmpty(w, r, v) } else if r.Method == http.MethodGet && r.URL.Path == webConfigPath { return s.handleWebConfig(w, r) + } else if r.Method == http.MethodGet && r.URL.Path == userStatsPath { + return s.handleUserStats(w, r, v) } else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) { return s.handleStatic(w, r) } else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) { @@ -351,6 +354,19 @@ var config = { return err } +func (s *Server) handleUserStats(w http.ResponseWriter, r *http.Request, v *visitor) error { + stats, err := v.Stats() + if err != nil { + return err + } + w.Header().Set("Content-Type", "text/json") + w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests + if err := json.NewEncoder(w).Encode(stats); err != nil { + return err + } + return nil +} + func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request) error { r.URL.Path = webSiteDir + r.URL.Path util.Gzip(http.FileServer(http.FS(webFsCached))).ServeHTTP(w, r) @@ -395,8 +411,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito if err != nil { return err } - return errHTTPEntityTooLargeAttachmentTooLarge - body, err := util.Peak(r.Body, s.config.MessageLimit) + body, err := util.Peek(r.Body, s.config.MessageLimit) if err != nil { return err } @@ -540,35 +555,35 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca // If file.txt is <= 4096 (message limit) and valid UTF-8, treat it as a message // 5. curl -T file.txt ntfy.sh/mytopic // If file.txt is > message limit, treat it as an attachment -func (s *Server) handlePublishBody(r *http.Request, v *visitor, m *message, body *util.PeakedReadCloser, unifiedpush bool) error { +func (s *Server) handlePublishBody(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser, unifiedpush bool) error { if unifiedpush { return s.handleBodyAsMessageAutoDetect(m, body) // Case 1 } else if m.Attachment != nil && m.Attachment.URL != "" { return s.handleBodyAsTextMessage(m, body) // Case 2 } else if m.Attachment != nil && m.Attachment.Name != "" { return s.handleBodyAsAttachment(r, v, m, body) // Case 3 - } else if !body.LimitReached && utf8.Valid(body.PeakedBytes) { + } else if !body.LimitReached && utf8.Valid(body.PeekedBytes) { return s.handleBodyAsTextMessage(m, body) // Case 4 } return s.handleBodyAsAttachment(r, v, m, body) // Case 5 } -func (s *Server) handleBodyAsMessageAutoDetect(m *message, body *util.PeakedReadCloser) error { - if utf8.Valid(body.PeakedBytes) { - m.Message = string(body.PeakedBytes) // Do not trim +func (s *Server) handleBodyAsMessageAutoDetect(m *message, body *util.PeekedReadCloser) error { + if utf8.Valid(body.PeekedBytes) { + m.Message = string(body.PeekedBytes) // Do not trim } else { - m.Message = base64.StdEncoding.EncodeToString(body.PeakedBytes) + m.Message = base64.StdEncoding.EncodeToString(body.PeekedBytes) m.Encoding = encodingBase64 } return nil } -func (s *Server) handleBodyAsTextMessage(m *message, body *util.PeakedReadCloser) error { - if !utf8.Valid(body.PeakedBytes) { +func (s *Server) handleBodyAsTextMessage(m *message, body *util.PeekedReadCloser) error { + if !utf8.Valid(body.PeekedBytes) { return errHTTPBadRequestMessageNotUTF8 } - if len(body.PeakedBytes) > 0 { // Empty body should not override message (publish via GET!) - m.Message = strings.TrimSpace(string(body.PeakedBytes)) // Truncates the message to the peak limit if required + if len(body.PeekedBytes) > 0 { // Empty body should not override message (publish via GET!) + m.Message = strings.TrimSpace(string(body.PeekedBytes)) // Truncates the message to the peek limit if required } if m.Attachment != nil && m.Attachment.Name != "" && m.Message == "" { m.Message = fmt.Sprintf(defaultAttachmentMessage, m.Attachment.Name) @@ -576,21 +591,20 @@ func (s *Server) handleBodyAsTextMessage(m *message, body *util.PeakedReadCloser return nil } -func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message, body *util.PeakedReadCloser) error { +func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser) error { if s.fileCache == nil || s.config.BaseURL == "" || s.config.AttachmentCacheDir == "" { return errHTTPBadRequestAttachmentsDisallowed } else if m.Time > time.Now().Add(s.config.AttachmentExpiryDuration).Unix() { return errHTTPBadRequestAttachmentsExpiryBeforeDelivery } - visitorAttachmentsSize, err := s.messageCache.AttachmentsSize(v.ip) + visitorStats, err := v.Stats() if err != nil { return err } - remainingVisitorAttachmentSize := s.config.VisitorAttachmentTotalSizeLimit - visitorAttachmentsSize contentLengthStr := r.Header.Get("Content-Length") if contentLengthStr != "" { // Early "do-not-trust" check, hard limit see below contentLength, err := strconv.ParseInt(contentLengthStr, 10, 64) - if err == nil && (contentLength > remainingVisitorAttachmentSize || contentLength > s.config.AttachmentFileSizeLimit) { + if err == nil && (contentLength > visitorStats.VisitorAttachmentBytesRemaining || contentLength > s.config.AttachmentFileSizeLimit) { return errHTTPEntityTooLargeAttachmentTooLarge } } @@ -600,7 +614,7 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message, var ext string m.Attachment.Owner = v.ip // Important for attachment rate limiting m.Attachment.Expires = time.Now().Add(s.config.AttachmentExpiryDuration).Unix() - m.Attachment.Type, ext = util.DetectContentType(body.PeakedBytes, m.Attachment.Name) + m.Attachment.Type, ext = util.DetectContentType(body.PeekedBytes, m.Attachment.Name) m.Attachment.URL = fmt.Sprintf("%s/file/%s%s", s.config.BaseURL, m.ID, ext) if m.Attachment.Name == "" { m.Attachment.Name = fmt.Sprintf("attachment%s", ext) @@ -608,7 +622,7 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message, if m.Message == "" { m.Message = fmt.Sprintf(defaultAttachmentMessage, m.Attachment.Name) } - m.Attachment.Size, err = s.fileCache.Write(m.ID, body, v.BandwidthLimiter(), util.NewFixedLimiter(remainingVisitorAttachmentSize)) + m.Attachment.Size, err = s.fileCache.Write(m.ID, body, v.BandwidthLimiter(), util.NewFixedLimiter(visitorStats.VisitorAttachmentBytesRemaining)) if err == util.ErrLimitReached { return errHTTPEntityTooLargeAttachmentTooLarge } else if err != nil { @@ -1097,11 +1111,11 @@ func (s *Server) limitRequests(next handleFunc) handleFunc { } } -// transformBodyJSON peaks the request body, reads the JSON, and converts it to headers +// transformBodyJSON peeks the request body, reads the JSON, and converts it to headers // before passing it on to the next handler. This is meant to be used in combination with handlePublish. func (s *Server) transformBodyJSON(next handleFunc) handleFunc { return func(w http.ResponseWriter, r *http.Request, v *visitor) error { - body, err := util.Peak(r.Body, s.config.MessageLimit) + body, err := util.Peek(r.Body, s.config.MessageLimit) if err != nil { return err } @@ -1217,7 +1231,7 @@ func (s *Server) visitor(r *http.Request) *visitor { } v, exists := s.visitors[ip] if !exists { - s.visitors[ip] = newVisitor(s.config, ip) + s.visitors[ip] = newVisitor(s.config, s.messageCache, ip) return s.visitors[ip] } v.Keepalive() diff --git a/server/server_test.go b/server/server_test.go index 67aff540..ff1a9000 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -938,7 +938,7 @@ func TestServer_PublishAttachment(t *testing.T) { require.Equal(t, content, response.Body.String()) // Slightly unrelated cross-test: make sure we add an owner for internal attachments - size, err := s.messageCache.AttachmentsSize("9.9.9.9") // See request() + size, err := s.messageCache.AttachmentBytesUsed("9.9.9.9") // See request() require.Nil(t, err) require.Equal(t, int64(5000), size) } @@ -967,7 +967,7 @@ func TestServer_PublishAttachmentShortWithFilename(t *testing.T) { require.Equal(t, content, response.Body.String()) // Slightly unrelated cross-test: make sure we add an owner for internal attachments - size, err := s.messageCache.AttachmentsSize("1.2.3.4") + size, err := s.messageCache.AttachmentBytesUsed("1.2.3.4") require.Nil(t, err) require.Equal(t, int64(21), size) } @@ -987,7 +987,7 @@ func TestServer_PublishAttachmentExternalWithoutFilename(t *testing.T) { require.Equal(t, "", msg.Attachment.Owner) // Slightly unrelated cross-test: make sure we don't add an owner for external attachments - size, err := s.messageCache.AttachmentsSize("127.0.0.1") + size, err := s.messageCache.AttachmentBytesUsed("127.0.0.1") require.Nil(t, err) require.Equal(t, int64(0), size) } diff --git a/server/visitor.go b/server/visitor.go index 948fe44c..58cc28ab 100644 --- a/server/visitor.go +++ b/server/visitor.go @@ -22,6 +22,7 @@ var ( // visitor represents an API user, and its associated rate.Limiter used for rate limiting type visitor struct { config *Config + messageCache *messageCache ip string requests *rate.Limiter emails *rate.Limiter @@ -31,9 +32,17 @@ type visitor struct { mu sync.Mutex } -func newVisitor(conf *Config, ip string) *visitor { +type visitorStats struct { + AttachmentFileSizeLimit int64 `json:"attachmentFileSizeLimit"` + VisitorAttachmentBytesTotal int64 `json:"visitorAttachmentBytesTotal"` + VisitorAttachmentBytesUsed int64 `json:"visitorAttachmentBytesUsed"` + VisitorAttachmentBytesRemaining int64 `json:"visitorAttachmentBytesRemaining"` +} + +func newVisitor(conf *Config, messageCache *messageCache, ip string) *visitor { return &visitor{ config: conf, + messageCache: messageCache, ip: ip, requests: rate.NewLimiter(rate.Every(conf.VisitorRequestLimitReplenish), conf.VisitorRequestLimitBurst), emails: rate.NewLimiter(rate.Every(conf.VisitorEmailLimitReplenish), conf.VisitorEmailLimitBurst), @@ -91,3 +100,20 @@ func (v *visitor) Stale() bool { defer v.mu.Unlock() return time.Since(v.seen) > visitorExpungeAfter } + +func (v *visitor) Stats() (*visitorStats, error) { + attachmentsBytesUsed, err := v.messageCache.AttachmentBytesUsed(v.ip) + if err != nil { + return nil, err + } + attachmentsBytesRemaining := v.config.VisitorAttachmentTotalSizeLimit - attachmentsBytesUsed + if attachmentsBytesRemaining < 0 { + attachmentsBytesRemaining = 0 + } + return &visitorStats{ + AttachmentFileSizeLimit: v.config.AttachmentFileSizeLimit, + VisitorAttachmentBytesTotal: v.config.VisitorAttachmentTotalSizeLimit, + VisitorAttachmentBytesUsed: attachmentsBytesUsed, + VisitorAttachmentBytesRemaining: attachmentsBytesRemaining, + }, nil +} diff --git a/util/limit.go b/util/limit.go index 8df768ad..36d8f66f 100644 --- a/util/limit.go +++ b/util/limit.go @@ -15,6 +15,10 @@ var ErrLimitReached = errors.New("limit reached") type Limiter interface { // Allow adds n to the limiters internal value, or returns ErrLimitReached if the limit has been reached Allow(n int64) error + + // Remaining returns the remaining count until the limit is reached; may return -1 if the implementation + // does not support this operation. + Remaining() int64 } // FixedLimiter is a helper that allows adding values up to a well-defined limit. Once the limit is reached @@ -44,6 +48,13 @@ func (l *FixedLimiter) Allow(n int64) error { return nil } +// Remaining returns the remaining count until the limit is reached +func (l *FixedLimiter) Remaining() int64 { + l.mu.Lock() + defer l.mu.Unlock() + return l.limit - l.value +} + // RateLimiter is a Limiter that wraps a rate.Limiter, allowing a floating time-based limit. type RateLimiter struct { limiter *rate.Limiter @@ -74,6 +85,11 @@ func (l *RateLimiter) Allow(n int64) error { return nil } +// Remaining is not implemented for RateLimiter. It always returns -1. +func (l *RateLimiter) Remaining() int64 { + return -1 +} + // LimitWriter implements an io.Writer that will pass through all Write calls to the underlying // writer w until any of the limiter's limit is reached, at which point a Write will return ErrLimitReached. // Each limiter's value is increased with every write. diff --git a/util/peak.go b/util/peak.go deleted file mode 100644 index 100c269b..00000000 --- a/util/peak.go +++ /dev/null @@ -1,61 +0,0 @@ -package util - -import ( - "bytes" - "io" - "strings" -) - -// PeakedReadCloser is a ReadCloser that allows peaking into a stream and buffering it in memory. -// It can be instantiated using the Peak function. After a stream has been peaked, it can still be fully -// read by reading the PeakedReadCloser. It first drained from the memory buffer, and then from the remaining -// underlying reader. -type PeakedReadCloser struct { - PeakedBytes []byte - LimitReached bool - peaked io.Reader - underlying io.ReadCloser - closed bool -} - -// Peak reads the underlying ReadCloser into memory up until the limit and returns a PeakedReadCloser -func Peak(underlying io.ReadCloser, limit int) (*PeakedReadCloser, error) { - if underlying == nil { - underlying = io.NopCloser(strings.NewReader("")) - } - peaked := make([]byte, limit) - read, err := io.ReadFull(underlying, peaked) - if err != nil && err != io.ErrUnexpectedEOF && err != io.EOF { - return nil, err - } - return &PeakedReadCloser{ - PeakedBytes: peaked[:read], - LimitReached: read == limit, - underlying: underlying, - peaked: bytes.NewReader(peaked[:read]), - closed: false, - }, nil -} - -// Read reads from the peaked bytes and then from the underlying stream -func (r *PeakedReadCloser) Read(p []byte) (n int, err error) { - if r.closed { - return 0, io.EOF - } - n, err = r.peaked.Read(p) - if err == io.EOF { - return r.underlying.Read(p) - } else if err != nil { - return 0, err - } - return -} - -// Close closes the underlying stream -func (r *PeakedReadCloser) Close() error { - if r.closed { - return io.EOF - } - r.closed = true - return r.underlying.Close() -} diff --git a/util/peek.go b/util/peek.go new file mode 100644 index 00000000..f7219253 --- /dev/null +++ b/util/peek.go @@ -0,0 +1,61 @@ +package util + +import ( + "bytes" + "io" + "strings" +) + +// PeekedReadCloser is a ReadCloser that allows peeking into a stream and buffering it in memory. +// It can be instantiated using the Peek function. After a stream has been peeked, it can still be fully +// read by reading the PeekedReadCloser. It first drained from the memory buffer, and then from the remaining +// underlying reader. +type PeekedReadCloser struct { + PeekedBytes []byte + LimitReached bool + peeked io.Reader + underlying io.ReadCloser + closed bool +} + +// Peek reads the underlying ReadCloser into memory up until the limit and returns a PeekedReadCloser +func Peek(underlying io.ReadCloser, limit int) (*PeekedReadCloser, error) { + if underlying == nil { + underlying = io.NopCloser(strings.NewReader("")) + } + peeked := make([]byte, limit) + read, err := io.ReadFull(underlying, peeked) + if err != nil && err != io.ErrUnexpectedEOF && err != io.EOF { + return nil, err + } + return &PeekedReadCloser{ + PeekedBytes: peeked[:read], + LimitReached: read == limit, + underlying: underlying, + peeked: bytes.NewReader(peeked[:read]), + closed: false, + }, nil +} + +// Read reads from the peeked bytes and then from the underlying stream +func (r *PeekedReadCloser) Read(p []byte) (n int, err error) { + if r.closed { + return 0, io.EOF + } + n, err = r.peeked.Read(p) + if err == io.EOF { + return r.underlying.Read(p) + } else if err != nil { + return 0, err + } + return +} + +// Close closes the underlying stream +func (r *PeekedReadCloser) Close() error { + if r.closed { + return io.EOF + } + r.closed = true + return r.underlying.Close() +} diff --git a/util/peak_test.go b/util/peek_test.go similarity index 74% rename from util/peak_test.go rename to util/peek_test.go index 76995179..e076394c 100644 --- a/util/peak_test.go +++ b/util/peek_test.go @@ -9,11 +9,11 @@ import ( func TestPeak_LimitReached(t *testing.T) { underlying := io.NopCloser(strings.NewReader("1234567890")) - peaked, err := Peak(underlying, 5) + peaked, err := Peek(underlying, 5) if err != nil { t.Fatal(err) } - require.Equal(t, []byte("12345"), peaked.PeakedBytes) + require.Equal(t, []byte("12345"), peaked.PeekedBytes) require.Equal(t, true, peaked.LimitReached) all, err := io.ReadAll(peaked) @@ -21,13 +21,13 @@ func TestPeak_LimitReached(t *testing.T) { t.Fatal(err) } require.Equal(t, []byte("1234567890"), all) - require.Equal(t, []byte("12345"), peaked.PeakedBytes) + require.Equal(t, []byte("12345"), peaked.PeekedBytes) require.Equal(t, true, peaked.LimitReached) } func TestPeak_LimitNotReached(t *testing.T) { underlying := io.NopCloser(strings.NewReader("1234567890")) - peaked, err := Peak(underlying, 15) + peaked, err := Peek(underlying, 15) if err != nil { t.Fatal(err) } @@ -36,12 +36,12 @@ func TestPeak_LimitNotReached(t *testing.T) { t.Fatal(err) } require.Equal(t, []byte("1234567890"), all) - require.Equal(t, []byte("1234567890"), peaked.PeakedBytes) + require.Equal(t, []byte("1234567890"), peaked.PeekedBytes) require.Equal(t, false, peaked.LimitReached) } func TestPeak_Nil(t *testing.T) { - peaked, err := Peak(nil, 15) + peaked, err := Peek(nil, 15) if err != nil { t.Fatal(err) } @@ -50,6 +50,6 @@ func TestPeak_Nil(t *testing.T) { t.Fatal(err) } require.Equal(t, []byte(""), all) - require.Equal(t, []byte(""), peaked.PeakedBytes) + require.Equal(t, []byte(""), peaked.PeekedBytes) require.Equal(t, false, peaked.LimitReached) } diff --git a/web/src/app/Api.js b/web/src/app/Api.js index 56fb9007..8f823cab 100644 --- a/web/src/app/Api.js +++ b/web/src/app/Api.js @@ -7,7 +7,7 @@ import { topicUrl, topicUrlAuth, topicUrlJsonPoll, - topicUrlJsonPollWithSince + topicUrlJsonPollWithSince, userStatsUrl } from "./utils"; import userManager from "./UserManager"; @@ -93,6 +93,16 @@ class Api { } throw new Error(`Unexpected server response ${response.status}`); } + + async userStats(baseUrl) { + const url = userStatsUrl(baseUrl); + console.log(`[Api] Fetching user stats ${url}`); + const response = await fetch(url); + if (response.status !== 200) { + throw new Error(`Unexpected server response ${response.status}`); + } + return response.json(); + } } const api = new Api(); diff --git a/web/src/app/utils.js b/web/src/app/utils.js index 62eee838..cf1398cf 100644 --- a/web/src/app/utils.js +++ b/web/src/app/utils.js @@ -18,6 +18,7 @@ export const topicUrlJsonPoll = (baseUrl, topic) => `${topicUrlJson(baseUrl, top export const topicUrlJsonPollWithSince = (baseUrl, topic, since) => `${topicUrlJson(baseUrl, topic)}?poll=1&since=${since}`; export const topicUrlAuth = (baseUrl, topic) => `${topicUrl(baseUrl, topic)}/auth`; export const topicShortUrl = (baseUrl, topic) => shortUrl(topicUrl(baseUrl, topic)); +export const userStatsUrl = (baseUrl) => `${baseUrl}/user/stats`; export const shortUrl = (url) => url.replaceAll(/https?:\/\//g, ""); export const expandUrl = (url) => [`https://${url}`, `http://${url}`]; export const expandSecureUrl = (url) => `https://${url}`; diff --git a/web/src/components/App.js b/web/src/components/App.js index 4dcb9990..e25c732a 100644 --- a/web/src/components/App.js +++ b/web/src/components/App.js @@ -154,11 +154,6 @@ const Messaging = (props) => { e.preventDefault(); } }; - const handleDrop = (e) => { - e.preventDefault(); - setShowDropZone(false); - console.log(e.dataTransfer.files[0]); - }; return ( <> @@ -173,6 +168,7 @@ const Messaging = (props) => { open={showDialog} dropZone={showDropZone} onClose={handleSendDialogClose} + onDrop={() => setShowDropZone(false)} topicUrl={selectedTopicUrl} message={message} /> diff --git a/web/src/components/SendDialog.js b/web/src/components/SendDialog.js index 488e01f4..4d6eb6fd 100644 --- a/web/src/components/SendDialog.js +++ b/web/src/components/SendDialog.js @@ -26,7 +26,7 @@ import api from "../app/Api"; import userManager from "../app/UserManager"; const SendDialog = (props) => { - const [topicUrl, setTopicUrl] = useState(props.topicUrl); + const [topicUrl, setTopicUrl] = useState(""); const [message, setMessage] = useState(props.message || ""); const [title, setTitle] = useState(""); const [tags, setTags] = useState(""); @@ -40,7 +40,7 @@ const SendDialog = (props) => { const [delay, setDelay] = useState(""); const [publishAnother, setPublishAnother] = useState(false); - const [showTopicUrl, setShowTopicUrl] = useState(props.topicUrl === ""); // FIXME + const [showTopicUrl, setShowTopicUrl] = useState(""); const [showClickUrl, setShowClickUrl] = useState(false); const [showAttachUrl, setShowAttachUrl] = useState(false); const [showEmail, setShowEmail] = useState(false); @@ -48,21 +48,25 @@ const SendDialog = (props) => { const showAttachFile = !!attachFile && !showAttachUrl; const attachFileInput = useRef(); + const [attachFileError, setAttachFileError] = useState(""); const [activeRequest, setActiveRequest] = useState(null); const [statusText, setStatusText] = useState(""); const disabled = !!activeRequest; - const dropZone = props.dropZone; + const [sendButtonEnabled, setSendButtonEnabled] = useState(true); + const dropZone = props.dropZone; const fullScreen = useMediaQuery(theme.breakpoints.down('sm')); - const sendButtonEnabled = (() => { - if (!validTopicUrl(topicUrl)) { - return false; - } - return true; - })(); + useEffect(() => { + setTopicUrl(props.topicUrl); + setShowTopicUrl(props.topicUrl === "") + }, [props.topicUrl]); + + useEffect(() => { + setSendButtonEnabled(validTopicUrl(topicUrl) && !attachFileError); + }, [topicUrl, attachFileError]); const handleSubmit = async () => { const { baseUrl, topic } = splitTopicUrl(topicUrl); @@ -124,23 +128,47 @@ const SendDialog = (props) => { setActiveRequest(null); }; + const checkAttachmentLimits = async (file) => { + try { + const { baseUrl } = splitTopicUrl(topicUrl); + const stats = await api.userStats(baseUrl); + console.log(`[SendDialog] Visitor attachment limits`, stats); + + const fileSizeLimit = stats.attachmentFileSizeLimit ?? 0; + if (fileSizeLimit > 0 && file.size > fileSizeLimit) { + return setAttachFileError(`exceeds ${formatBytes(fileSizeLimit)} limit`); + } + + const remainingBytes = stats.visitorAttachmentBytesRemaining ?? 0; + if (remainingBytes > 0 && file.size > remainingBytes) { + return setAttachFileError(`quota reached, only ${formatBytes(remainingBytes)} remaining`); + } + + setAttachFileError(""); + } catch (e) { + console.log(`[SendDialog] Retrieving attachment limits failed`, e); + setAttachFileError(""); // Reset error (rely on server-side checking) + } + }; + const handleAttachFileClick = () => { attachFileInput.current.click(); }; - const handleAttachFileChanged = (ev) => { - const file = ev.target.files[0]; - setAttachFile(file); - setFilename(file.name); - console.log(ev.target.files[0]); - console.log(URL.createObjectURL(ev.target.files[0])); + const handleAttachFileChanged = async (ev) => { + await updateAttachFile(ev.target.files[0]); }; - const handleDrop = (ev) => { + const handleAttachFileDrop = async (ev) => { ev.preventDefault(); - const file = ev.dataTransfer.files[0]; + props.onDrop(); + await updateAttachFile(ev.dataTransfer.files[0]); + }; + + const updateAttachFile = async (file) => { setAttachFile(file); setFilename(file.name); + await checkAttachmentLimits(file); }; const allowDrag = (ev) => { @@ -178,7 +206,7 @@ const SendDialog = (props) => { justifyContent: "center", alignItems: "center", }} - onDrop={handleDrop} + onDrop={handleAttachFileDrop} onDragEnter={allowDrag} onDragOver={allowDrag} > @@ -360,9 +388,11 @@ const SendDialog = (props) => { file={attachFile} filename={filename} disabled={disabled} + error={attachFileError} onChangeFilename={(f) => setFilename(f)} onClose={() => { setAttachFile(null); + setAttachFileError(""); setFilename(""); }} />} @@ -466,7 +496,7 @@ const AttachmentBox = (props) => { borderRadius: '4px', }}> - + { disabled={props.disabled} />
- {formatBytes(file.size)} -
+ + {formatBytes(file.size)} + {props.error && + + {" "}({props.error}) + + } + +