From 57814cf8555ce2f9677a60e51d0b776cff791f3f Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Thu, 29 Dec 2022 09:57:42 -0500 Subject: [PATCH] Tests --- server/errors.go | 10 ++- server/server.go | 19 ++-- server/server_account.go | 26 +++--- server/server_account_test.go | 157 ++++++++++++++++++++++++++++++++++ server/server_matrix.go | 2 +- server/server_matrix_test.go | 2 +- server/util.go | 13 +++ util/util.go | 17 ++-- 8 files changed, 209 insertions(+), 37 deletions(-) diff --git a/server/errors.go b/server/errors.go index d7a1cdc9..12cdc5aa 100644 --- a/server/errors.go +++ b/server/errors.go @@ -48,19 +48,21 @@ var ( errHTTPBadRequestAttachmentsDisallowed = &errHTTP{40014, http.StatusBadRequest, "invalid request: attachments not allowed", "https://ntfy.sh/docs/config/#attachments"} errHTTPBadRequestAttachmentsExpiryBeforeDelivery = &errHTTP{40015, http.StatusBadRequest, "invalid request: attachment expiry before delayed delivery date", "https://ntfy.sh/docs/publish/#scheduled-delivery"} errHTTPBadRequestWebSocketsUpgradeHeaderMissing = &errHTTP{40016, http.StatusBadRequest, "invalid request: client not using the websocket protocol", "https://ntfy.sh/docs/subscribe/api/#websockets"} - errHTTPBadRequestJSONInvalid = &errHTTP{40017, http.StatusBadRequest, "invalid request: request body must be message JSON", "https://ntfy.sh/docs/publish/#publish-as-json"} + errHTTPBadRequestMessageJSONInvalid = &errHTTP{40017, http.StatusBadRequest, "invalid request: request body must be message JSON", "https://ntfy.sh/docs/publish/#publish-as-json"} errHTTPBadRequestActionsInvalid = &errHTTP{40018, http.StatusBadRequest, "invalid request: actions invalid", "https://ntfy.sh/docs/publish/#action-buttons"} errHTTPBadRequestMatrixMessageInvalid = &errHTTP{40019, http.StatusBadRequest, "invalid request: Matrix JSON invalid", "https://ntfy.sh/docs/publish/#matrix-gateway"} errHTTPBadRequestMatrixPushkeyBaseURLMismatch = &errHTTP{40020, http.StatusBadRequest, "invalid request: push key must be prefixed with base URL", "https://ntfy.sh/docs/publish/#matrix-gateway"} errHTTPBadRequestIconURLInvalid = &errHTTP{40021, http.StatusBadRequest, "invalid request: icon URL is invalid", "https://ntfy.sh/docs/publish/#icons"} errHTTPBadRequestSignupNotEnabled = &errHTTP{40022, http.StatusBadRequest, "invalid request: signup not enabled", "https://ntfy.sh/docs/config"} errHTTPBadRequestNoTokenProvided = &errHTTP{40023, http.StatusBadRequest, "invalid request: no token provided", ""} + errHTTPBadRequestJSONInvalid = &errHTTP{40024, http.StatusBadRequest, "invalid request: request body must be valid JSON", ""} errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", ""} errHTTPUnauthorized = &errHTTP{40101, http.StatusUnauthorized, "unauthorized", "https://ntfy.sh/docs/publish/#authentication"} errHTTPForbidden = &errHTTP{40301, http.StatusForbidden, "forbidden", "https://ntfy.sh/docs/publish/#authentication"} errHTTPConflictUserExists = &errHTTP{40901, http.StatusConflict, "conflict: user already exists", ""} - errHTTPEntityTooLargeAttachmentTooLarge = &errHTTP{41301, http.StatusRequestEntityTooLarge, "attachment too large, or bandwidth limit reached", "https://ntfy.sh/docs/publish/#limitations"} - errHTTPEntityTooLargeMatrixRequestTooLarge = &errHTTP{41302, http.StatusRequestEntityTooLarge, "Matrix request is larger than the max allowed length", ""} + errHTTPEntityTooLargeAttachment = &errHTTP{41301, http.StatusRequestEntityTooLarge, "attachment too large, or bandwidth limit reached", "https://ntfy.sh/docs/publish/#limitations"} + errHTTPEntityTooLargeMatrixRequest = &errHTTP{41302, http.StatusRequestEntityTooLarge, "Matrix request is larger than the max allowed length", ""} + errHTTPEntityTooLargeJSONBody = &errHTTP{41303, http.StatusRequestEntityTooLarge, "JSON body too large", ""} errHTTPTooManyRequestsLimitRequests = &errHTTP{42901, http.StatusTooManyRequests, "limit reached: too many requests, please be nice", "https://ntfy.sh/docs/publish/#limitations"} errHTTPTooManyRequestsLimitEmails = &errHTTP{42902, http.StatusTooManyRequests, "limit reached: too many emails, please be nice", "https://ntfy.sh/docs/publish/#limitations"} errHTTPTooManyRequestsLimitSubscriptions = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"} @@ -68,6 +70,6 @@ var ( errHTTPTooManyRequestsAttachmentBandwidthLimit = &errHTTP{42905, http.StatusTooManyRequests, "too many requests: daily bandwidth limit reached", "https://ntfy.sh/docs/publish/#limitations"} errHTTPTooManyRequestsAccountCreateLimit = &errHTTP{42906, http.StatusTooManyRequests, "too many requests: daily account creation limit reached", "https://ntfy.sh/docs/publish/#limitations"} // FIXME document limit errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""} - errHTTPInternalErrorInvalidFilePath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid file path", ""} + errHTTPInternalErrorInvalidPath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid file path", ""} errHTTPInternalErrorMissingBaseURL = &errHTTP{50003, http.StatusInternalServerError, "internal server error: base-url must be be configured for this feature", "https://ntfy.sh/docs/config/"} ) diff --git a/server/server.go b/server/server.go index f8c3655c..be9a9105 100644 --- a/server/server.go +++ b/server/server.go @@ -479,7 +479,7 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor) } matches := fileRegex.FindStringSubmatch(r.URL.Path) if len(matches) != 2 { - return errHTTPInternalErrorInvalidFilePath + return errHTTPInternalErrorInvalidPath } messageID := matches[1] file := filepath.Join(s.config.AttachmentCacheDir, messageID) @@ -815,7 +815,7 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message, if contentLengthStr != "" { // Early "do-not-trust" check, hard limit see below contentLength, err := strconv.ParseInt(contentLengthStr, 10, 64) if err == nil && (contentLength > stats.AttachmentTotalSizeRemaining || contentLength > stats.AttachmentFileSizeLimit) { - return errHTTPEntityTooLargeAttachmentTooLarge + return errHTTPEntityTooLargeAttachment } } if m.Attachment == nil { @@ -839,7 +839,7 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message, } m.Attachment.Size, err = s.fileCache.Write(m.ID, body, limiters...) if err == util.ErrLimitReached { - return errHTTPEntityTooLargeAttachmentTooLarge + return errHTTPEntityTooLargeAttachment } else if err != nil { return err } @@ -1426,15 +1426,10 @@ func (s *Server) ensureUser(next handleFunc) handleFunc { // before passing it on to the next handler. This is meant to be used in combination with handlePublish. func (s *Server) transformBodyJSON(next handleFunc) handleFunc { return func(w http.ResponseWriter, r *http.Request, v *visitor) error { - body, err := util.Peek(r.Body, s.config.MessageLimit) + m, err := readJSONWithLimit[publishMessage](r.Body, s.config.MessageLimit) if err != nil { return err } - defer r.Body.Close() - var m publishMessage - if err := json.NewDecoder(body).Decode(&m); err != nil { - return errHTTPBadRequestJSONInvalid - } if !topicRegex.MatchString(m.Topic) { return errHTTPBadRequestTopicInvalid } @@ -1467,7 +1462,7 @@ func (s *Server) transformBodyJSON(next handleFunc) handleFunc { if len(m.Actions) > 0 { actionsStr, err := json.Marshal(m.Actions) if err != nil { - return errHTTPBadRequestJSONInvalid + return errHTTPBadRequestMessageJSONInvalid } r.Header.Set("X-Actions", string(actionsStr)) } @@ -1535,7 +1530,9 @@ func (s *Server) visitor(r *http.Request) (v *visitor, err error) { } else { v = s.visitorFromIP(ip) } - v.user = u // Update user -- FIXME race? + v.mu.Lock() + v.user = u + v.mu.Unlock() return v, err // Always return visitor, even when error occurs! } diff --git a/server/server_account.go b/server/server_account.go index 6006f08c..42133868 100644 --- a/server/server_account.go +++ b/server/server_account.go @@ -2,7 +2,6 @@ package server import ( "encoding/json" - "errors" "heckel.io/ntfy/user" "heckel.io/ntfy/util" "net/http" @@ -21,7 +20,7 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v * return errHTTPUnauthorized // Cannot create account from user context } } - newAccount, err := util.ReadJSONWithLimit[apiAccountCreateRequest](r.Body, jsonBodyBytesLimit) + newAccount, err := readJSONWithLimit[apiAccountCreateRequest](r.Body, jsonBodyBytesLimit) if err != nil { return err } @@ -118,7 +117,7 @@ func (s *Server) handleAccountDelete(w http.ResponseWriter, r *http.Request, v * } func (s *Server) handleAccountPasswordChange(w http.ResponseWriter, r *http.Request, v *visitor) error { - newPassword, err := util.ReadJSONWithLimit[apiAccountPasswordChangeRequest](r.Body, jsonBodyBytesLimit) + newPassword, err := readJSONWithLimit[apiAccountPasswordChangeRequest](r.Body, jsonBodyBytesLimit) if err != nil { return err } @@ -174,7 +173,7 @@ func (s *Server) handleAccountTokenExtend(w http.ResponseWriter, r *http.Request func (s *Server) handleAccountTokenDelete(w http.ResponseWriter, r *http.Request, v *visitor) error { // TODO rate limit if v.user.Token == "" { - return errHTTPUnauthorized + return errHTTPBadRequestNoTokenProvided } if err := s.userManager.RemoveToken(v.user); err != nil { return err @@ -184,7 +183,7 @@ func (s *Server) handleAccountTokenDelete(w http.ResponseWriter, r *http.Request } func (s *Server) handleAccountSettingsChange(w http.ResponseWriter, r *http.Request, v *visitor) error { - newPrefs, err := util.ReadJSONWithLimit[user.Prefs](r.Body, jsonBodyBytesLimit) + newPrefs, err := readJSONWithLimit[user.Prefs](r.Body, jsonBodyBytesLimit) if err != nil { return err } @@ -218,7 +217,7 @@ func (s *Server) handleAccountSettingsChange(w http.ResponseWriter, r *http.Requ } func (s *Server) handleAccountSubscriptionAdd(w http.ResponseWriter, r *http.Request, v *visitor) error { - newSubscription, err := util.ReadJSONWithLimit[user.Subscription](r.Body, jsonBodyBytesLimit) + newSubscription, err := readJSONWithLimit[user.Subscription](r.Body, jsonBodyBytesLimit) if err != nil { return err } @@ -250,13 +249,13 @@ func (s *Server) handleAccountSubscriptionAdd(w http.ResponseWriter, r *http.Req func (s *Server) handleAccountSubscriptionChange(w http.ResponseWriter, r *http.Request, v *visitor) error { matches := accountSubscriptionSingleRegex.FindStringSubmatch(r.URL.Path) if len(matches) != 2 { - return errHTTPInternalErrorInvalidFilePath // FIXME + return errHTTPInternalErrorInvalidPath } - updatedSubscription, err := util.ReadJSONWithLimit[user.Subscription](r.Body, jsonBodyBytesLimit) + subscriptionID := matches[1] + updatedSubscription, err := readJSONWithLimit[user.Subscription](r.Body, jsonBodyBytesLimit) if err != nil { return err } - subscriptionID := matches[1] if v.user.Prefs == nil || v.user.Prefs.Subscriptions == nil { return errHTTPNotFound } @@ -283,14 +282,9 @@ func (s *Server) handleAccountSubscriptionChange(w http.ResponseWriter, r *http. } func (s *Server) handleAccountSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error { - if v.user == nil { - return errors.New("no user") - } - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this matches := accountSubscriptionSingleRegex.FindStringSubmatch(r.URL.Path) if len(matches) != 2 { - return errHTTPInternalErrorInvalidFilePath // FIXME + return errHTTPInternalErrorInvalidPath } subscriptionID := matches[1] if v.user.Prefs == nil || v.user.Prefs.Subscriptions == nil { @@ -308,5 +302,7 @@ func (s *Server) handleAccountSubscriptionDelete(w http.ResponseWriter, r *http. return err } } + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this return nil } diff --git a/server/server_account_test.go b/server/server_account_test.go index 7b059eed..d4411418 100644 --- a/server/server_account_test.go +++ b/server/server_account_test.go @@ -62,6 +62,25 @@ func TestAccount_Signup_LimitReached(t *testing.T) { require.Equal(t, 42906, toHTTPError(t, rr.Body.String()).Code) } +func TestAccount_Signup_AsUser(t *testing.T) { + conf := newTestConfigWithUsers(t) + conf.EnableSignup = true + s := newTestServer(t, conf) + + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin)) + require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser)) + + rr := request(t, s, "POST", "/v1/account", `{"username":"emma", "password":"emma"}`, map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 200, rr.Code) + + rr = request(t, s, "POST", "/v1/account", `{"username":"marian", "password":"marian"}`, map[string]string{ + "Authorization": util.BasicAuth("ben", "ben"), + }) + require.Equal(t, 401, rr.Code) +} + func TestAccount_Signup_Disabled(t *testing.T) { conf := newTestConfigWithUsers(t) conf.EnableSignup = false @@ -112,6 +131,144 @@ func TestAccount_Get_Anonymous(t *testing.T) { require.Equal(t, int64(23), account.Stats.EmailsRemaining) } +func TestAccount_ChangeSettings(t *testing.T) { + s := newTestServer(t, newTestConfigWithUsers(t)) + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) + user, _ := s.userManager.User("phil") + token, _ := s.userManager.CreateToken(user) + + rr := request(t, s, "PATCH", "/v1/account/settings", `{"notification": {"sound": "juntos"},"ignored": true}`, map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 200, rr.Code) + + rr = request(t, s, "PATCH", "/v1/account/settings", `{"notification": {"delete_after": 86400}, "language": "de"}`, map[string]string{ + "Authorization": util.BearerAuth(token.Value), + }) + require.Equal(t, 200, rr.Code) + + rr = request(t, s, "GET", "/v1/account", `{"username":"marian", "password":"marian"}`, map[string]string{ + "Authorization": util.BearerAuth(token.Value), + }) + require.Equal(t, 200, rr.Code) + account, _ := util.ReadJSON[apiAccountResponse](io.NopCloser(rr.Body)) + require.Equal(t, "de", account.Language) + require.Equal(t, 86400, account.Notification.DeleteAfter) + require.Equal(t, "juntos", account.Notification.Sound) + require.Equal(t, 0, account.Notification.MinPriority) // Not set +} + +func TestAccount_Subscription_AddUpdateDelete(t *testing.T) { + s := newTestServer(t, newTestConfigWithUsers(t)) + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) + + rr := request(t, s, "POST", "/v1/account/subscription", `{"base_url": "http://abc.com", "topic": "def"}`, map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 200, rr.Code) + + rr = request(t, s, "GET", "/v1/account", "", map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 200, rr.Code) + account, _ := util.ReadJSON[apiAccountResponse](io.NopCloser(rr.Body)) + require.Equal(t, 1, len(account.Subscriptions)) + require.NotEmpty(t, account.Subscriptions[0].ID) + require.Equal(t, "http://abc.com", account.Subscriptions[0].BaseURL) + require.Equal(t, "def", account.Subscriptions[0].Topic) + require.Equal(t, "", account.Subscriptions[0].DisplayName) + + subscriptionID := account.Subscriptions[0].ID + rr = request(t, s, "PATCH", "/v1/account/subscription/"+subscriptionID, `{"display_name": "ding dong"}`, map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 200, rr.Code) + + rr = request(t, s, "GET", "/v1/account", "", map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 200, rr.Code) + account, _ = util.ReadJSON[apiAccountResponse](io.NopCloser(rr.Body)) + require.Equal(t, 1, len(account.Subscriptions)) + require.Equal(t, subscriptionID, account.Subscriptions[0].ID) + require.Equal(t, "http://abc.com", account.Subscriptions[0].BaseURL) + require.Equal(t, "def", account.Subscriptions[0].Topic) + require.Equal(t, "ding dong", account.Subscriptions[0].DisplayName) + + rr = request(t, s, "DELETE", "/v1/account/subscription/"+subscriptionID, "", map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 200, rr.Code) + + rr = request(t, s, "GET", "/v1/account", "", map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 200, rr.Code) + account, _ = util.ReadJSON[apiAccountResponse](io.NopCloser(rr.Body)) + require.Equal(t, 0, len(account.Subscriptions)) +} + +func TestAccount_ChangePassword(t *testing.T) { + s := newTestServer(t, newTestConfigWithUsers(t)) + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) + + rr := request(t, s, "POST", "/v1/account/password", `{"password": "new password"}`, map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 200, rr.Code) + + rr = request(t, s, "GET", "/v1/account", "", map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 401, rr.Code) + + rr = request(t, s, "GET", "/v1/account", "", map[string]string{ + "Authorization": util.BasicAuth("phil", "new password"), + }) + require.Equal(t, 200, rr.Code) +} + +func TestAccount_ChangePassword_NoAccount(t *testing.T) { + s := newTestServer(t, newTestConfigWithUsers(t)) + + rr := request(t, s, "POST", "/v1/account/password", `{"password": "new password"}`, nil) + require.Equal(t, 401, rr.Code) +} + +func TestAccount_ExtendToken(t *testing.T) { + s := newTestServer(t, newTestConfigWithUsers(t)) + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) + + rr := request(t, s, "POST", "/v1/account/token", "", map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 200, rr.Code) + token, err := util.ReadJSON[apiAccountTokenResponse](io.NopCloser(rr.Body)) + require.Nil(t, err) + + time.Sleep(time.Second) + + rr = request(t, s, "PATCH", "/v1/account/token", "", map[string]string{ + "Authorization": util.BearerAuth(token.Token), + }) + require.Equal(t, 200, rr.Code) + extendedToken, err := util.ReadJSON[apiAccountTokenResponse](io.NopCloser(rr.Body)) + require.Nil(t, err) + require.Equal(t, token.Token, extendedToken.Token) + require.True(t, token.Expires < extendedToken.Expires) +} + +func TestAccount_ExtendToken_NoTokenProvided(t *testing.T) { + s := newTestServer(t, newTestConfigWithUsers(t)) + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) + + rr := request(t, s, "PATCH", "/v1/account/token", "", map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), // Not Bearer! + }) + require.Equal(t, 400, rr.Code) + require.Equal(t, 40023, toHTTPError(t, rr.Body.String()).Code) +} + func TestAccount_Delete_Success(t *testing.T) { conf := newTestConfigWithUsers(t) conf.EnableSignup = true diff --git a/server/server_matrix.go b/server/server_matrix.go index 5c985f04..99d8dc34 100644 --- a/server/server_matrix.go +++ b/server/server_matrix.go @@ -113,7 +113,7 @@ func newRequestFromMatrixJSON(r *http.Request, baseURL string, messageLimit int) } defer r.Body.Close() if body.LimitReached { - return nil, errHTTPEntityTooLargeMatrixRequestTooLarge + return nil, errHTTPEntityTooLargeMatrixRequest } var m matrixRequest if err := json.Unmarshal(body.PeekedBytes, &m); err != nil { diff --git a/server/server_matrix_test.go b/server/server_matrix_test.go index ad94da5d..77d6d4ac 100644 --- a/server/server_matrix_test.go +++ b/server/server_matrix_test.go @@ -29,7 +29,7 @@ func TestMatrix_NewRequestFromMatrixJSON_TooLarge(t *testing.T) { body := `{"notification":{"content":{"body":"I'm floating in a most peculiar way.","msgtype":"m.text"},"counts":{"missed_calls":1,"unread":2},"devices":[{"app_id":"org.matrix.matrixConsole.ios","data":{},"pushkey":"https://ntfy.sh/upABCDEFGHI?up=1","pushkey_ts":12345678,"tweaks":{"sound":"bing"}}],"event_id":"$3957tyerfgewrf384","prio":"high","room_alias":"#exampleroom:matrix.org","room_id":"!slw48wfj34rtnrf:example.com","room_name":"Mission Control","sender":"@exampleuser:matrix.org","sender_display_name":"Major Tom","type":"m.room.message"}}` r, _ := http.NewRequest("POST", "http://ntfy.example.com/_matrix/push/v1/notify", strings.NewReader(body)) _, err := newRequestFromMatrixJSON(r, baseURL, maxLength) - require.Equal(t, errHTTPEntityTooLargeMatrixRequestTooLarge, err) + require.Equal(t, errHTTPEntityTooLargeMatrixRequest, err) } func TestMatrix_NewRequestFromMatrixJSON_InvalidJSON(t *testing.T) { diff --git a/server/util.go b/server/util.go index b2c7b2e6..1fbe2843 100644 --- a/server/util.go +++ b/server/util.go @@ -5,6 +5,7 @@ import ( "github.com/emersion/go-smtp" "heckel.io/ntfy/log" "heckel.io/ntfy/util" + "io" "net/http" "net/netip" "strings" @@ -121,3 +122,15 @@ func extractIPAddress(r *http.Request, behindProxy bool) netip.Addr { } return ip } + +func readJSONWithLimit[T any](r io.ReadCloser, limit int) (*T, error) { + obj, err := util.ReadJSONWithLimit[T](r, limit) + if err == util.ErrInvalidJSON { + return nil, errHTTPBadRequestJSONInvalid + } else if err == util.ErrTooLargeJSON { + return nil, errHTTPEntityTooLargeJSONBody + } else if err != nil { + return nil, err + } + return obj, nil +} diff --git a/util/util.go b/util/util.go index c9535e46..abc0df9e 100644 --- a/util/util.go +++ b/util/util.go @@ -31,6 +31,11 @@ var ( noQuotesRegex = regexp.MustCompile(`^[-_./:@a-zA-Z0-9]+$`) ) +var ( + ErrInvalidJSON = errors.New("invalid JSON") + ErrTooLargeJSON = errors.New("too large JSON") +) + // FileExists checks if a file exists, and returns true if it does func FileExists(filename string) bool { stat, _ := os.Stat(filename) @@ -293,21 +298,23 @@ func QuoteCommand(command []string) string { func ReadJSON[T any](body io.ReadCloser) (*T, error) { var obj T if err := json.NewDecoder(body).Decode(&obj); err != nil { - return nil, err + return nil, ErrInvalidJSON } return &obj, nil } // ReadJSONWithLimit reads the given io.ReadCloser into a struct, but only until limit is reached func ReadJSONWithLimit[T any](r io.ReadCloser, limit int) (*T, error) { - r, err := Peek(r, limit) + defer r.Close() + p, err := Peek(r, limit) if err != nil { return nil, err + } else if p.LimitReached { + return nil, ErrTooLargeJSON } - defer r.Close() var obj T - if err := json.NewDecoder(r).Decode(&obj); err != nil { - return nil, err + if err := json.NewDecoder(p).Decode(&obj); err != nil { + return nil, ErrInvalidJSON } return &obj, nil }