diff --git a/server/server.go b/server/server.go index 5042e477..deb94b54 100644 --- a/server/server.go +++ b/server/server.go @@ -298,7 +298,7 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit } else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && r.URL.Path == "/" { return s.limitRequests(s.transformBodyJSON(s.authWrite(s.handlePublish)))(w, r, v) } else if r.Method == http.MethodPost && r.URL.Path == matrixPushPath { - return s.limitRequests(s.transformMatrixJSON(s.authWrite(s.handlePublish)))(w, r, v) + return s.limitRequests(s.transformMatrixJSON(s.authWrite(s.handlePublishMatrix)))(w, r, v) } else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicPathRegex.MatchString(r.URL.Path) { return s.limitRequests(s.authWrite(s.handlePublish))(w, r, v) } else if r.Method == http.MethodGet && publishPathRegex.MatchString(r.URL.Path) { @@ -428,25 +428,25 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor) return nil } -func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visitor) error { +func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*message, error) { t, err := s.topicFromPath(r.URL.Path) if err != nil { - return err + return nil, err } body, err := util.Peek(r.Body, s.config.MessageLimit) if err != nil { - return err + return nil, err } m := newDefaultMessage(t.ID, "") cache, firebase, email, unifiedpush, err := s.parsePublishParams(r, v, m) if err != nil { - return err + return nil, err } if m.PollID != "" { m = newPollRequestMessage(t.ID, m.PollID) } if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil { - return err + return nil, err } if m.Message == "" { m.Message = emptyMessageBody @@ -459,7 +459,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito } if !delayed { if err := t.Publish(v, m); err != nil { - return err + return nil, err } if s.firebaseClient != nil && firebase { go s.sendToFirebase(v, m) @@ -475,17 +475,44 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito } if cache { if err := s.messageCache.AddMessage(m); err != nil { - return err + return nil, err } } + s.mu.Lock() + s.messages++ + s.mu.Unlock() + return m, nil +} + +func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visitor) error { + m, err := s.handlePublishWithoutResponse(r, v) + if err != nil { + return err + } w.Header().Set("Content-Type", "application/json") w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests if err := json.NewEncoder(w).Encode(m); err != nil { return err } - s.mu.Lock() - s.messages++ - s.mu.Unlock() + return nil +} + +func (s *Server) handlePublishMatrix(w http.ResponseWriter, r *http.Request, v *visitor) error { + pushKey := r.Header.Get("X-Matrix-Pushkey") + if pushKey == "" { + return errHTTPBadRequestMatrixMessageInvalid + } + response := &matrixResponse{ + Rejected: make([]string, 0), + } + _, err := s.handlePublishWithoutResponse(r, v) + if err != nil { + response.Rejected = append(response.Rejected, pushKey) + } + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(response); err != nil { + return err + } return nil } @@ -1301,6 +1328,10 @@ type matrixDevice struct { PushKey string `json:"pushkey"` } +type matrixResponse struct { + Rejected []string `json:"rejected"` +} + func (s *Server) transformMatrixJSON(next handleFunc) handleFunc { return func(w http.ResponseWriter, r *http.Request, v *visitor) error { if s.config.BaseURL == "" { @@ -1314,26 +1345,40 @@ func (s *Server) transformMatrixJSON(next handleFunc) handleFunc { var m matrixMessage if err := json.NewDecoder(body).Decode(&m); err != nil { return errHTTPBadRequestMatrixMessageInvalid - } else if m.Notification == nil || len(m.Notification.Devices) == 0 { - return errHTTPBadRequestMatrixMessageInvalid - } else if !strings.HasPrefix(m.Notification.Devices[0].PushKey, s.config.BaseURL+"/") { + } else if m.Notification == nil || len(m.Notification.Devices) == 0 || m.Notification.Devices[0].PushKey == "" { return errHTTPBadRequestMatrixMessageInvalid } - u, err := url.Parse(m.Notification.Devices[0].PushKey) + pushKey := m.Notification.Devices[0].PushKey + if !strings.HasPrefix(pushKey, s.config.BaseURL+"/") { + return matrixError(w, pushKey, errHTTPBadRequestMatrixMessageInvalid) + } + u, err := url.Parse(pushKey) if err != nil { - return errHTTPBadRequestMatrixMessageInvalid + return matrixError(w, pushKey, errHTTPBadRequestMatrixMessageInvalid) } r.URL.Path = u.Path r.URL.RawQuery = u.RawQuery r.RequestURI = u.RequestURI() r.Body = io.NopCloser(bytes.NewReader(body.PeekedBytes)) + r.Header.Set("X-Matrix-Pushkey", pushKey) if err := next(w, r, v); err != nil { - return nil + return matrixError(w, pushKey, errHTTPBadRequestMatrixMessageInvalid) } return nil } } +func matrixError(w http.ResponseWriter, pushKey string, err error) error { + log.Debug("Matrix message with push key %s rejected: %s", pushKey, err.Error()) + response := &matrixResponse{ + Rejected: []string{pushKey}, + } + if err := json.NewEncoder(w).Encode(response); err != nil { + return err + } + return nil +} + func (s *Server) authWrite(next handleFunc) handleFunc { return s.withAuth(next, auth.PermissionWrite) }