1
0
Fork 0
mirror of https://github.com/binwiederhier/ntfy.git synced 2025-09-01 17:44:51 +02:00

Add web push tests

This commit is contained in:
nimbleghost 2023-05-29 17:57:21 +02:00
parent ff5c854192
commit a9fef387fa
20 changed files with 372 additions and 41 deletions

View file

@ -233,8 +233,10 @@ func NewConfig() *Config {
EnableReservations: false,
AccessControlAllowOrigin: "*",
Version: "",
WebPushEnabled: false,
WebPushPrivateKey: "",
WebPushPublicKey: "",
WebPushSubscriptionsFile: "",
WebPushEmailAddress: "",
}
}

View file

@ -77,7 +77,7 @@ var (
rawPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/raw$`)
wsPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/ws$`)
authPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/auth$`)
webPushPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/web-push$`)
webPushSubscribePathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/web-push/subscribe$`)
webPushUnsubscribePathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/web-push/unsubscribe$`)
publishPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}/(publish|send|trigger)$`)
@ -535,7 +535,7 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeWS))(w, r, v)
} else if r.Method == http.MethodGet && authPathRegex.MatchString(r.URL.Path) {
return s.limitRequests(s.authorizeTopicRead(s.handleTopicAuth))(w, r, v)
} else if r.Method == http.MethodPost && webPushPathRegex.MatchString(r.URL.Path) {
} else if r.Method == http.MethodPost && webPushSubscribePathRegex.MatchString(r.URL.Path) {
return s.limitRequestsWithTopic(s.authorizeTopicRead(s.ensureWebPushEnabled(s.handleTopicWebPushSubscribe)))(w, r, v)
} else if r.Method == http.MethodPost && webPushUnsubscribePathRegex.MatchString(r.URL.Path) {
return s.limitRequestsWithTopic(s.authorizeTopicRead(s.ensureWebPushEnabled(s.handleTopicWebPushUnsubscribe)))(w, r, v)
@ -985,7 +985,6 @@ func (s *Server) publishToWebPushEndpoints(v *visitor, m *message) {
return
}
failedCount := 0
totalCount := len(subscriptions)
wg := &sync.WaitGroup{}
@ -1029,12 +1028,11 @@ func (s *Server) publishToWebPushEndpoints(v *visitor, m *message) {
jsonPayload, err := json.Marshal(payload)
if err != nil {
failedCount++
logvm(v, m).Err(err).Fields(ctx).Debug("Unable to publish web push message")
return
}
_, err = webpush.SendNotification(jsonPayload, &sub.BrowserSubscription, &webpush.Options{
resp, err := webpush.SendNotification(jsonPayload, &sub.BrowserSubscription, &webpush.Options{
Subscriber: s.config.WebPushEmailAddress,
VAPIDPublicKey: s.config.WebPushPublicKey,
VAPIDPrivateKey: s.config.WebPushPrivateKey,
@ -1044,26 +1042,29 @@ func (s *Server) publishToWebPushEndpoints(v *visitor, m *message) {
})
if err != nil {
failedCount++
logvm(v, m).Err(err).Fields(ctx).Debug("Unable to publish web push message")
// probably need to handle different codes differently,
// but for now just expire the subscription on any error
err = s.webPushSubscriptionStore.ExpireWebPushEndpoint(sub.BrowserSubscription.Endpoint)
if err != nil {
logvm(v, m).Err(err).Fields(ctx).Warn("Unable to expire subscription")
}
return
}
// May want to handle at least 429 differently, but for now treat all errors the same
if !(200 <= resp.StatusCode && resp.StatusCode <= 299) {
logvm(v, m).Fields(ctx).Field("response", resp).Debug("Unable to publish web push message")
err = s.webPushSubscriptionStore.ExpireWebPushEndpoint(sub.BrowserSubscription.Endpoint)
if err != nil {
logvm(v, m).Err(err).Fields(ctx).Warn("Unable to expire subscription")
}
return
}
}(i, xi)
}
ctx = log.Context{"topic": m.Topic, "message_id": m.ID, "failed_count": failedCount, "total_count": totalCount}
if failedCount > 0 {
logvm(v, m).Fields(ctx).Warn("Unable to publish web push messages to %d of %d endpoints", failedCount, totalCount)
} else {
logvm(v, m).Fields(ctx).Debug("Published %d web push messages successfully", totalCount)
}
}
func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, firebase bool, email, call string, unifiedpush bool, err *errHTTP) {

View file

@ -40,7 +40,7 @@
# Enable web push
#
# Run ntfy web-push-keys to generate the keys
# Run ntfy web-push generate-keys to generate the keys
#
# web-push-enabled: true
# web-push-public-key: ""

View file

@ -22,6 +22,7 @@ import (
"testing"
"time"
"github.com/SherClockHolmes/webpush-go"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/log"
"heckel.io/ntfy/util"
@ -2604,14 +2605,35 @@ func newTestConfig(t *testing.T) *Config {
return conf
}
func newTestConfigWithAuthFile(t *testing.T) *Config {
conf := newTestConfig(t)
func configureAuth(t *testing.T, conf *Config) *Config {
conf.AuthFile = filepath.Join(t.TempDir(), "user.db")
conf.AuthStartupQueries = "pragma journal_mode = WAL; pragma synchronous = normal; pragma temp_store = memory;"
conf.AuthBcryptCost = bcrypt.MinCost // This speeds up tests a lot
return conf
}
func newTestConfigWithAuthFile(t *testing.T) *Config {
conf := newTestConfig(t)
conf = configureAuth(t, conf)
return conf
}
func newTestConfigWithWebPush(t *testing.T) *Config {
conf := newTestConfig(t)
privateKey, publicKey, err := webpush.GenerateVAPIDKeys()
if err != nil {
t.Fatal(err)
}
conf.WebPushEnabled = true
conf.WebPushSubscriptionsFile = filepath.Join(t.TempDir(), "subscriptions.db")
conf.WebPushEmailAddress = "testing@example.com"
conf.WebPushPrivateKey = privateKey
conf.WebPushPublicKey = publicKey
return conf
}
func newTestServer(t *testing.T, config *Config) *Server {
server, err := New(config)
if err != nil {

View file

@ -0,0 +1,212 @@
package server
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"github.com/SherClockHolmes/webpush-go"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/user"
"heckel.io/ntfy/util"
)
var (
webPushSubscribePayloadExample = `{
"browser_subscription":{
"endpoint": "https://example.com/webpush",
"keys": {
"p256dh": "p256dh-key",
"auth": "auth-key"
}
}
}`
)
func TestServer_WebPush_GetConfig(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
response := request(t, s, "GET", "/v1/web-push-config", "", nil)
require.Equal(t, 200, response.Code)
require.Equal(t, fmt.Sprintf(`{"public_key":"%s"}`, s.config.WebPushPublicKey)+"\n", response.Body.String())
}
func TestServer_WebPush_TopicSubscribe(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
response := request(t, s, "POST", "/test-topic/web-push/subscribe", webPushSubscribePayloadExample, nil)
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
subs, err := s.webPushSubscriptionStore.GetSubscriptionsForTopic("test-topic")
if err != nil {
t.Fatal(err)
}
require.Len(t, subs, 1)
require.Equal(t, subs[0].BrowserSubscription.Endpoint, "https://example.com/webpush")
require.Equal(t, subs[0].BrowserSubscription.Keys.P256dh, "p256dh-key")
require.Equal(t, subs[0].BrowserSubscription.Keys.Auth, "auth-key")
require.Equal(t, subs[0].Username, "")
}
func TestServer_WebPush_TopicSubscribeProtected_Allowed(t *testing.T) {
config := configureAuth(t, newTestConfigWithWebPush(t))
config.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, config)
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, s.userManager.AllowAccess("ben", "test-topic", user.PermissionReadWrite))
response := request(t, s, "POST", "/test-topic/web-push/subscribe", webPushSubscribePayloadExample, map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
subs, err := s.webPushSubscriptionStore.GetSubscriptionsForTopic("test-topic")
if err != nil {
t.Fatal(err)
}
require.Len(t, subs, 1)
require.Equal(t, subs[0].Username, "ben")
}
func TestServer_WebPush_TopicSubscribeProtected_Denied(t *testing.T) {
config := configureAuth(t, newTestConfigWithWebPush(t))
config.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, config)
response := request(t, s, "POST", "/test-topic/web-push/subscribe", webPushSubscribePayloadExample, nil)
require.Equal(t, 403, response.Code)
requireSubscriptionCount(t, s, "test-topic", 0)
}
func TestServer_WebPush_TopicUnsubscribe(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
response := request(t, s, "POST", "/test-topic/web-push/subscribe", webPushSubscribePayloadExample, nil)
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
requireSubscriptionCount(t, s, "test-topic", 1)
unsubscribe := `{"endpoint":"https://example.com/webpush"}`
response = request(t, s, "POST", "/test-topic/web-push/unsubscribe", unsubscribe, nil)
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
requireSubscriptionCount(t, s, "test-topic", 0)
}
func TestServer_WebPush_DeleteAccountUnsubscribe(t *testing.T) {
config := configureAuth(t, newTestConfigWithWebPush(t))
config.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, config)
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, s.userManager.AllowAccess("ben", "test-topic", user.PermissionReadWrite))
response := request(t, s, "POST", "/test-topic/web-push/subscribe", webPushSubscribePayloadExample, map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
requireSubscriptionCount(t, s, "test-topic", 1)
request(t, s, "DELETE", "/v1/account", `{"password":"ben"}`, map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
// should've been deleted with the account
requireSubscriptionCount(t, s, "test-topic", 0)
}
func TestServer_WebPush_Publish(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
var received atomic.Bool
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := io.ReadAll(r.Body)
require.Nil(t, err)
require.Equal(t, "/push-receive", r.URL.Path)
require.Equal(t, "high", r.Header.Get("Urgency"))
require.Equal(t, "", r.Header.Get("Topic"))
received.Store(true)
}))
defer upstreamServer.Close()
addSubscription(t, s, "test-topic", upstreamServer.URL+"/push-receive")
request(t, s, "PUT", "/test-topic", "web push test", nil)
waitFor(t, func() bool {
return received.Load()
})
}
func TestServer_WebPush_PublishExpire(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
var received atomic.Bool
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := io.ReadAll(r.Body)
require.Nil(t, err)
// Gone
w.WriteHeader(410)
w.Write([]byte(``))
received.Store(true)
}))
defer upstreamServer.Close()
addSubscription(t, s, "test-topic", upstreamServer.URL+"/push-receive")
addSubscription(t, s, "test-topic-abc", upstreamServer.URL+"/push-receive")
requireSubscriptionCount(t, s, "test-topic", 1)
requireSubscriptionCount(t, s, "test-topic-abc", 1)
request(t, s, "PUT", "/test-topic", "web push test", nil)
waitFor(t, func() bool {
return received.Load()
})
// Receiving the 410 should've caused the publisher to expire all subscriptions on the endpoint
requireSubscriptionCount(t, s, "test-topic", 0)
requireSubscriptionCount(t, s, "test-topic-abc", 0)
}
func addSubscription(t *testing.T, s *Server, topic string, url string) {
err := s.webPushSubscriptionStore.AddSubscription("test-topic", "", webPushSubscribePayload{
BrowserSubscription: webpush.Subscription{
Endpoint: url,
Keys: webpush.Keys{
// connected to a local test VAPID key, not a leak!
Auth: "kSC3T8aN1JCQxxPdrFLrZg",
P256dh: "BMKKbxdUU_xLS7G1Wh5AN8PvWOjCzkCuKZYb8apcqYrDxjOF_2piggBnoJLQYx9IeSD70fNuwawI3e9Y8m3S3PE",
},
},
})
if err != nil {
t.Fatal(err)
}
}
func requireSubscriptionCount(t *testing.T, s *Server, topic string, expectedLength int) {
subs, err := s.webPushSubscriptionStore.GetSubscriptionsForTopic("test-topic")
if err != nil {
t.Fatal(err)
}
require.Len(t, subs, expectedLength)
}