From ebbc2838ba5472ee10fe0c4bba09494bc08c814d Mon Sep 17 00:00:00 2001
From: Philipp Heckel <pheckel@datto.com>
Date: Wed, 15 Jun 2022 20:36:49 -0400
Subject: [PATCH] Move error handling to main error handling; move matrix logic
 to its own file

---
 server/errors.go        |  2 +-
 server/server.go        | 36 +++++---------------
 server/server_matrix.go | 73 +++++++++++++++++++++++++++++++++--------
 3 files changed, 68 insertions(+), 43 deletions(-)

diff --git a/server/errors.go b/server/errors.go
index 28aa4be6..5a62de2e 100644
--- a/server/errors.go
+++ b/server/errors.go
@@ -51,7 +51,7 @@ var (
 	errHTTPBadRequestJSONInvalid                     = &errHTTP{40017, http.StatusBadRequest, "invalid request: request body must be message JSON", "https://ntfy.sh/docs/publish/#publish-as-json"}
 	errHTTPBadRequestActionsInvalid                  = &errHTTP{40018, http.StatusBadRequest, "invalid request: actions invalid", "https://ntfy.sh/docs/publish/#action-buttons"}
 	errHTTPBadRequestMatrixMessageInvalid            = &errHTTP{40019, http.StatusBadRequest, "invalid request: Matrix JSON invalid", "https://ntfy.sh/docs/publish/#matrix-gateway"}
-	errHTTPBadRequestMatrixPushkeyBaseURLMismatch    = &errHTTP{40020, http.StatusBadRequest, "invalid request: Push key must be prefixed with base URL", "https://ntfy.sh/docs/publish/#matrix-gateway"}
+	errHTTPBadRequestMatrixPushkeyBaseURLMismatch    = &errHTTP{40020, http.StatusBadRequest, "invalid request: push key must be prefixed with base URL", "https://ntfy.sh/docs/publish/#matrix-gateway"}
 	errHTTPNotFound                                  = &errHTTP{40401, http.StatusNotFound, "page not found", ""}
 	errHTTPUnauthorized                              = &errHTTP{40101, http.StatusUnauthorized, "unauthorized", "https://ntfy.sh/docs/publish/#authentication"}
 	errHTTPForbidden                                 = &errHTTP{40301, http.StatusForbidden, "forbidden", "https://ntfy.sh/docs/publish/#authentication"}
diff --git a/server/server.go b/server/server.go
index 78d7fd9b..2aa114a2 100644
--- a/server/server.go
+++ b/server/server.go
@@ -259,6 +259,10 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
 			}
 			return // Do not attempt to write to upgraded connection
 		}
+		if matrixErr, ok := err.(*errMatrix); ok {
+			writeMatrixError(w, r, v, matrixErr)
+			return
+		}
 		httpErr, ok := err.(*errHTTP)
 		if !ok {
 			httpErr = errHTTPInternalError
@@ -506,8 +510,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
 func (s *Server) handlePublishMatrix(w http.ResponseWriter, r *http.Request, v *visitor) error {
 	_, err := s.handlePublishWithoutResponse(r, v)
 	if err != nil {
-		pushKey := r.Header.Get(matrixPushkeyHeader)
-		return writeMatrixError(w, pushKey, err)
+		return &errMatrix{pushKey: r.Header.Get(matrixPushKeyHeader), err: err}
 	}
 	return writeMatrixSuccess(w)
 }
@@ -1314,35 +1317,12 @@ func (s *Server) transformBodyJSON(next handleFunc) handleFunc {
 
 func (s *Server) transformMatrixJSON(next handleFunc) handleFunc {
 	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
-		if s.config.BaseURL == "" {
-			return errHTTPInternalErrorMissingBaseURL
-		}
-		body, err := util.Peek(r.Body, s.config.MessageLimit)
+		newRequest, err := newRequestFromMatrixJSON(r, s.config.BaseURL, s.config.MessageLimit)
 		if err != nil {
 			return err
 		}
-		defer r.Body.Close()
-		var m matrixMessage
-		if err := json.NewDecoder(body).Decode(&m); err != nil {
-			return errHTTPBadRequestMatrixMessageInvalid
-		} else if m.Notification == nil || len(m.Notification.Devices) == 0 || m.Notification.Devices[0].PushKey == "" {
-			return errHTTPBadRequestMatrixMessageInvalid
-		}
-		pushKey := m.Notification.Devices[0].PushKey
-		if !strings.HasPrefix(pushKey, s.config.BaseURL+"/") {
-			return writeMatrixError(w, pushKey, errHTTPBadRequestMatrixPushkeyBaseURLMismatch)
-		}
-		u, err := url.Parse(pushKey)
-		if err != nil {
-			return writeMatrixError(w, pushKey, errHTTPBadRequestMatrixMessageInvalid)
-		}
-		r.URL.Path = u.Path
-		r.URL.RawQuery = u.RawQuery
-		r.RequestURI = u.RequestURI()
-		r.Body = io.NopCloser(bytes.NewReader(body.PeekedBytes))
-		r.Header.Set(matrixPushkeyHeader, pushKey)
-		if err := next(w, r, v); err != nil {
-			return writeMatrixError(w, pushKey, errHTTPBadRequestMatrixMessageInvalid)
+		if err := next(w, newRequest, v); err != nil {
+			return &errMatrix{pushKey: newRequest.Header.Get(matrixPushKeyHeader), err: err}
 		}
 		return nil
 	}
diff --git a/server/server_matrix.go b/server/server_matrix.go
index c8b3eca4..c20aff46 100644
--- a/server/server_matrix.go
+++ b/server/server_matrix.go
@@ -1,14 +1,18 @@
 package server
 
 import (
+	"bytes"
 	"encoding/json"
+	"fmt"
 	"heckel.io/ntfy/log"
+	"heckel.io/ntfy/util"
 	"io"
 	"net/http"
+	"strings"
 )
 
 const (
-	matrixPushkeyHeader = "X-Matrix-Pushkey"
+	matrixPushKeyHeader = "X-Matrix-Pushkey"
 )
 
 type matrixMessage struct {
@@ -27,16 +31,67 @@ type matrixResponse struct {
 	Rejected []string `json:"rejected"`
 }
 
+type errMatrix struct {
+	pushKey string
+	err     error
+}
+
+func (e errMatrix) Error() string {
+	if e.err != nil {
+		return fmt.Sprintf("message with push key %s rejected: %s", e.pushKey, e.err.Error())
+	}
+	return fmt.Sprintf("message with push key %s rejected", e.pushKey)
+}
+
+func newRequestFromMatrixJSON(r *http.Request, baseURL string, messageLimit int) (*http.Request, error) {
+	if baseURL == "" {
+		return nil, errHTTPInternalErrorMissingBaseURL
+	}
+	body, err := util.Peek(r.Body, messageLimit)
+	if err != nil {
+		return nil, err
+	}
+	defer r.Body.Close()
+	var m matrixMessage
+	if err := json.NewDecoder(body).Decode(&m); err != nil {
+		return nil, errHTTPBadRequestMatrixMessageInvalid
+	} else if m.Notification == nil || len(m.Notification.Devices) == 0 || m.Notification.Devices[0].PushKey == "" {
+		return nil, errHTTPBadRequestMatrixMessageInvalid
+	}
+	pushKey := m.Notification.Devices[0].PushKey
+	if !strings.HasPrefix(pushKey, baseURL+"/") {
+		return nil, &errMatrix{pushKey: pushKey, err: errHTTPBadRequestMatrixPushkeyBaseURLMismatch}
+	}
+	newRequest, err := http.NewRequest(http.MethodPost, pushKey, io.NopCloser(bytes.NewReader(body.PeekedBytes)))
+	if err != nil {
+		return nil, &errMatrix{pushKey: pushKey, err: err}
+	}
+	newRequest.Header.Set(matrixPushKeyHeader, pushKey)
+	return newRequest, nil
+}
+
 func handleMatrixDiscovery(w http.ResponseWriter) error {
 	w.Header().Set("Content-Type", "application/json")
 	_, err := io.WriteString(w, `{"unifiedpush":{"gateway":"matrix"}}`+"\n")
 	return err
 }
 
-func writeMatrixError(w http.ResponseWriter, pushKey string, err error) error {
-	log.Debug("Matrix message with push key %s rejected: %s", pushKey, err.Error())
+func writeMatrixError(w http.ResponseWriter, r *http.Request, v *visitor, err *errMatrix) error {
+	log.Debug("%s Matrix gateway error: %s", logHTTPPrefix(v, r), err.Error())
+	return writeMatrixResponse(w, err.pushKey)
+}
+
+func writeMatrixSuccess(w http.ResponseWriter) error {
+	return writeMatrixResponse(w, "")
+}
+
+func writeMatrixResponse(w http.ResponseWriter, rejectedPushKey string) error {
+	rejected := make([]string, 0)
+	if rejectedPushKey != "" {
+		rejected = append(rejected, rejectedPushKey)
+	}
 	response := &matrixResponse{
-		Rejected: []string{pushKey},
+		Rejected: rejected,
 	}
 	w.Header().Set("Content-Type", "application/json")
 	if err := json.NewEncoder(w).Encode(response); err != nil {
@@ -44,13 +99,3 @@ func writeMatrixError(w http.ResponseWriter, pushKey string, err error) error {
 	}
 	return nil
 }
-
-func writeMatrixSuccess(w http.ResponseWriter) error {
-	response := &matrixResponse{
-		Rejected: make([]string, 0),
-	}
-	if err := json.NewEncoder(w).Encode(response); err != nil {
-		return err
-	}
-	return nil
-}