From 1abcc88fce53c041767f025fa89734ab5ed685ed Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Fri, 9 Jun 2023 23:17:48 -0400 Subject: [PATCH] Add subscription_topic table, change updated_at type to INT, split expire function --- server/config.go | 4 +- server/server_manager.go | 2 +- server/server_web_push.go | 25 +++-- server/server_web_push_test.go | 11 ++- server/types.go | 1 + server/webpush_store.go | 169 ++++++++++++++++++++------------- 6 files changed, 131 insertions(+), 81 deletions(-) diff --git a/server/config.go b/server/config.go index 3d779fba..3bdda835 100644 --- a/server/config.go +++ b/server/config.go @@ -23,10 +23,10 @@ const ( DefaultStripePriceCacheDuration = 3 * time.Hour // Time to keep Stripe prices cached in memory before a refresh is needed ) -// Defines default web push settings +// Defines default Web Push settings const ( DefaultWebPushExpiryWarningDuration = 7 * 24 * time.Hour - DefaultWebPushExpiryDuration = DefaultWebPushExpiryWarningDuration + 24*time.Hour + DefaultWebPushExpiryDuration = 9 * 24 * time.Hour ) // Defines all global and per-visitor limits diff --git a/server/server_manager.go b/server/server_manager.go index b065aff1..66d449de 100644 --- a/server/server_manager.go +++ b/server/server_manager.go @@ -15,7 +15,7 @@ func (s *Server) execManager() { s.pruneTokens() s.pruneAttachments() s.pruneMessages() - s.pruneOrNotifyWebPushSubscriptions() + s.pruneAndNotifyWebPushSubscriptions() // Message count per topic var messagesCached int diff --git a/server/server_web_push.go b/server/server_web_push.go index 0b9ac808..30a2cd02 100644 --- a/server/server_web_push.go +++ b/server/server_web_push.go @@ -50,7 +50,7 @@ func (s *Server) handleWebPushUpdate(w http.ResponseWriter, r *http.Request, v * } } } - if err := s.webPush.UpsertSubscription(req.Endpoint, req.Topics, v.MaybeUserID(), req.Auth, req.P256dh); err != nil { + if err := s.webPush.UpsertSubscription(req.Endpoint, req.Auth, req.P256dh, v.MaybeUserID(), req.Topics); err != nil { return err } return s.writeJSON(w, newSuccessResponse()) @@ -75,21 +75,25 @@ func (s *Server) publishToWebPushEndpoints(v *visitor, m *message) { } } -func (s *Server) pruneOrNotifyWebPushSubscriptions() { +func (s *Server) pruneAndNotifyWebPushSubscriptions() { if s.config.WebPushPublicKey == "" { return } go func() { - if err := s.pruneOrNotifyWebPushSubscriptionsInternal(); err != nil { + if err := s.pruneAndNotifyWebPushSubscriptionsInternal(); err != nil { log.Tag(tagWebPush).Err(err).Warn("Unable to prune or notify web push subscriptions") } }() } -func (s *Server) pruneOrNotifyWebPushSubscriptionsInternal() error { - subscriptions, err := s.webPush.ExpireAndGetExpiringSubscriptions(s.config.WebPushExpiryWarningDuration, s.config.WebPushExpiryDuration) +func (s *Server) pruneAndNotifyWebPushSubscriptionsInternal() error { + // Expire old subscriptions + if err := s.webPush.RemoveExpiredSubscriptions(s.config.WebPushExpiryDuration); err != nil { + return err + } + // Notify subscriptions that will expire soon + subscriptions, err := s.webPush.SubscriptionsExpiring(s.config.WebPushExpiryWarningDuration) if err != nil { - log.Tag(tagWebPush).Err(err).Warn("Unable to publish expiry imminent warning") return err } else if len(subscriptions) == 0 { return nil @@ -99,14 +103,19 @@ func (s *Server) pruneOrNotifyWebPushSubscriptionsInternal() error { log.Tag(tagWebPush).Err(err).Warn("Unable to marshal expiring payload") return err } + warningSent := make([]*webPushSubscription, 0) for _, subscription := range subscriptions { ctx := log.Context{"endpoint": subscription.Endpoint} if err := s.sendWebPushNotification(payload, subscription, &ctx); err != nil { log.Tag(tagWebPush).Err(err).Fields(ctx).Warn("Unable to publish expiry imminent warning") - return err + continue } + warningSent = append(warningSent, subscription) } - log.Tag(tagWebPush).Debug("Expiring old subscriptions and published %d expiry imminent warnings", len(subscriptions)) + if err := s.webPush.MarkExpiryWarningSent(warningSent); err != nil { + return err + } + log.Tag(tagWebPush).Debug("Expired old subscriptions and published %d expiry imminent warnings", len(subscriptions)) return nil } diff --git a/server/server_web_push_test.go b/server/server_web_push_test.go index c60ceaad..82ad7215 100644 --- a/server/server_web_push_test.go +++ b/server/server_web_push_test.go @@ -12,6 +12,7 @@ import ( "strings" "sync/atomic" "testing" + "time" ) const ( @@ -190,20 +191,20 @@ func TestServer_WebPush_Expiry(t *testing.T) { addSubscription(t, s, pushService.URL+"/push-receive", "test-topic") requireSubscriptionCount(t, s, "test-topic", 1) - _, err := s.webPush.db.Exec("UPDATE subscriptions SET updated_at = datetime('now', '-7 days')") + _, err := s.webPush.db.Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-7*24*time.Hour).Unix()) require.Nil(t, err) - s.pruneOrNotifyWebPushSubscriptions() + s.pruneAndNotifyWebPushSubscriptions() requireSubscriptionCount(t, s, "test-topic", 1) waitFor(t, func() bool { return received.Load() }) - _, err = s.webPush.db.Exec("UPDATE subscriptions SET updated_at = datetime('now', '-8 days')") + _, err = s.webPush.db.Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-9*24*time.Hour).Unix()) require.Nil(t, err) - s.pruneOrNotifyWebPushSubscriptions() + s.pruneAndNotifyWebPushSubscriptions() waitFor(t, func() bool { subs, err := s.webPush.SubscriptionsForTopic("test-topic") require.Nil(t, err) @@ -224,7 +225,7 @@ func payloadForTopics(t *testing.T, topics []string, endpoint string) string { } func addSubscription(t *testing.T, s *Server, endpoint string, topics ...string) { - require.Nil(t, s.webPush.UpsertSubscription(endpoint, topics, "", "kSC3T8aN1JCQxxPdrFLrZg", "BMKKbxdUU_xLS7G1Wh5AN8PvWOjCzkCuKZYb8apcqYrDxjOF_2piggBnoJLQYx9IeSD70fNuwawI3e9Y8m3S3PE")) // Test auth and p256dh + require.Nil(t, s.webPush.UpsertSubscription(endpoint, "kSC3T8aN1JCQxxPdrFLrZg", "BMKKbxdUU_xLS7G1Wh5AN8PvWOjCzkCuKZYb8apcqYrDxjOF_2piggBnoJLQYx9IeSD70fNuwawI3e9Y8m3S3PE", "u_123", topics)) // Test auth and p256dh } func requireSubscriptionCount(t *testing.T, s *Server, topic string, expectedLength int) { diff --git a/server/types.go b/server/types.go index 99f1c4f7..90995878 100644 --- a/server/types.go +++ b/server/types.go @@ -505,6 +505,7 @@ func newWebPushSubscriptionExpiringPayload() webPushControlMessagePayload { } type webPushSubscription struct { + ID string Endpoint string Auth string P256dh string diff --git a/server/webpush_store.go b/server/webpush_store.go index 774772be..6dc1ddef 100644 --- a/server/webpush_store.go +++ b/server/webpush_store.go @@ -2,47 +2,68 @@ package server import ( "database/sql" - "fmt" + "heckel.io/ntfy/util" "time" _ "github.com/mattn/go-sqlite3" // SQLite driver ) +const ( + subscriptionIDPrefix = "wps_" + subscriptionIDLength = 10 +) + const ( createWebPushSubscriptionsTableQuery = ` BEGIN; - CREATE TABLE IF NOT EXISTS subscriptions ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - topic TEXT NOT NULL, - user_id TEXT, + CREATE TABLE IF NOT EXISTS subscription ( + id TEXT PRIMARY KEY, endpoint TEXT NOT NULL, key_auth TEXT NOT NULL, key_p256dh TEXT NOT NULL, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, - warning_sent BOOLEAN DEFAULT FALSE + user_id TEXT NOT NULL, + updated_at INT NOT NULL, + warned_at INT NOT NULL DEFAULT 0 ); + CREATE UNIQUE INDEX IF NOT EXISTS idx_endpoint ON subscription (endpoint); + CREATE TABLE IF NOT EXISTS subscription_topic ( + subscription_id TEXT NOT NULL, + topic TEXT NOT NULL, + PRIMARY KEY (subscription_id, topic), + FOREIGN KEY (subscription_id) REFERENCES subscription (id) ON DELETE CASCADE + ); + CREATE INDEX IF NOT EXISTS idx_topic ON subscription_topic (topic); CREATE TABLE IF NOT EXISTS schemaVersion ( id INT PRIMARY KEY, version INT NOT NULL - ); - CREATE INDEX IF NOT EXISTS idx_topic ON subscriptions (topic); - CREATE INDEX IF NOT EXISTS idx_endpoint ON subscriptions (endpoint); - CREATE UNIQUE INDEX IF NOT EXISTS idx_topic_endpoint ON subscriptions (topic, endpoint); + ); COMMIT; ` - - insertWebPushSubscriptionQuery = ` - INSERT OR REPLACE INTO subscriptions (topic, user_id, endpoint, key_auth, key_p256dh) - VALUES (?, ?, ?, ?, ?) + builtinStartupQueries = ` + PRAGMA foreign_keys = ON; ` - deleteWebPushSubscriptionByEndpointQuery = `DELETE FROM subscriptions WHERE endpoint = ?` - deleteWebPushSubscriptionByUserIDQuery = `DELETE FROM subscriptions WHERE user_id = ?` - deleteWebPushSubscriptionsByAgeQuery = `DELETE FROM subscriptions WHERE warning_sent = 1 AND updated_at <= datetime('now', ?)` - selectWebPushSubscriptionsForTopicQuery = `SELECT endpoint, key_auth, key_p256dh, user_id FROM subscriptions WHERE topic = ?` - selectWebPushSubscriptionsExpiringSoonQuery = `SELECT DISTINCT endpoint, key_auth, key_p256dh, user_id FROM subscriptions WHERE warning_sent = 0 AND updated_at <= datetime('now', ?)` + selectWebPushSubscriptionIDByEndpoint = `SELECT id FROM subscription WHERE endpoint = ?` + selectWebPushSubscriptionsForTopicQuery = ` + SELECT id, endpoint, key_auth, key_p256dh, user_id + FROM subscription_topic st + JOIN subscription s ON s.id = st.subscription_id + WHERE st.topic = ? + ` + selectWebPushSubscriptionsExpiringSoonQuery = `SELECT id, endpoint, key_auth, key_p256dh, user_id FROM subscription WHERE warned_at = 0 AND updated_at <= ?` + insertWebPushSubscriptionQuery = ` + INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, updated_at, warned_at) + VALUES (?, ?, ?, ?, ?, ?, ?) + ON CONFLICT (endpoint) + DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, updated_at = excluded.updated_at, warned_at = excluded.warned_at + ` + updateWebPushSubscriptionWarningSentQuery = `UPDATE subscription SET warned_at = ? WHERE id = ?` + deleteWebPushSubscriptionByEndpointQuery = `DELETE FROM subscription WHERE endpoint = ?` + deleteWebPushSubscriptionByUserIDQuery = `DELETE FROM subscription WHERE user_id = ?` + deleteWebPushSubscriptionByAgeQuery = `DELETE FROM subscription WHERE updated_at <= ?` // Full table scan! - updateWarningSentQuery = `UPDATE subscriptions SET warning_sent = true WHERE warning_sent = 0 AND updated_at <= datetime('now', ?)` + insertWebPushSubscriptionTopicQuery = `INSERT INTO subscription_topic (subscription_id, topic) VALUES (?, ?)` + deleteWebPushSubscriptionTopicAllQuery = `DELETE FROM subscription_topic WHERE subscription_id = ?` ) // Schema management queries @@ -64,6 +85,9 @@ func newWebPushStore(filename string) (*webPushStore, error) { if err := setupWebPushDB(db); err != nil { return nil, err } + if err := runWebPushStartupQueries(db); err != nil { + return nil, err + } return &webPushStore{ db: db, }, nil @@ -88,19 +112,47 @@ func setupNewWebPushDB(db *sql.DB) error { return nil } +func runWebPushStartupQueries(db *sql.DB) error { + _, err := db.Exec(builtinStartupQueries) + return err +} + // UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID. It always first deletes all // existing entries for a given endpoint. -func (c *webPushStore) UpsertSubscription(endpoint string, topics []string, userID, auth, p256dh string) error { +func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, topics []string) error { tx, err := c.db.Begin() if err != nil { return err } defer tx.Rollback() - if _, err := tx.Exec(deleteWebPushSubscriptionByEndpointQuery, endpoint); err != nil { + // Read existing subscription ID for endpoint (or create new ID) + rows, err := tx.Query(selectWebPushSubscriptionIDByEndpoint, endpoint) + if err != nil { + return err + } + defer rows.Close() + var subscriptionID string + if rows.Next() { + if err := rows.Scan(&subscriptionID); err != nil { + return err + } + } else { + subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength) + } + if err := rows.Close(); err != nil { + return err + } + // Insert or update subscription + updatedAt, warnedAt := time.Now().Unix(), 0 + if _, err = tx.Exec(insertWebPushSubscriptionQuery, subscriptionID, endpoint, auth, p256dh, userID, updatedAt, warnedAt); err != nil { + return err + } + // Replace all subscription topics + if _, err := tx.Exec(deleteWebPushSubscriptionTopicAllQuery, subscriptionID); err != nil { return err } for _, topic := range topics { - if _, err = tx.Exec(insertWebPushSubscriptionQuery, topic, userID, endpoint, auth, p256dh); err != nil { + if _, err = tx.Exec(insertWebPushSubscriptionTopicQuery, subscriptionID, topic); err != nil { return err } } @@ -113,65 +165,47 @@ func (c *webPushStore) SubscriptionsForTopic(topic string) ([]*webPushSubscripti return nil, err } defer rows.Close() - subscriptions := make([]*webPushSubscription, 0) - for rows.Next() { - var endpoint, auth, p256dh, userID string - if err = rows.Scan(&endpoint, &auth, &p256dh, &userID); err != nil { - return nil, err - } - subscriptions = append(subscriptions, &webPushSubscription{ - Endpoint: endpoint, - Auth: auth, - P256dh: p256dh, - UserID: userID, - }) - } - return subscriptions, nil + return c.subscriptionsFromRows(rows) } -func (c *webPushStore) ExpireAndGetExpiringSubscriptions(warningDuration time.Duration, expiryDuration time.Duration) ([]*webPushSubscription, error) { - // TODO this should be two functions - tx, err := c.db.Begin() - if err != nil { - return nil, err - } - defer tx.Rollback() - - _, err = tx.Exec(deleteWebPushSubscriptionsByAgeQuery, fmt.Sprintf("-%.2f seconds", expiryDuration.Seconds())) - if err != nil { - return nil, err - } - - rows, err := tx.Query(selectWebPushSubscriptionsExpiringSoonQuery, fmt.Sprintf("-%.2f seconds", warningDuration.Seconds())) +func (c *webPushStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*webPushSubscription, error) { + rows, err := c.db.Query(selectWebPushSubscriptionsExpiringSoonQuery, time.Now().Add(-warnAfter).Unix()) if err != nil { return nil, err } defer rows.Close() + return c.subscriptionsFromRows(rows) +} +func (c *webPushStore) MarkExpiryWarningSent(subscriptions []*webPushSubscription) error { + tx, err := c.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + for _, subscription := range subscriptions { + if _, err := tx.Exec(updateWebPushSubscriptionWarningSentQuery, time.Now().Unix(), subscription.ID); err != nil { + return err + } + } + return tx.Commit() +} + +func (c *webPushStore) subscriptionsFromRows(rows *sql.Rows) ([]*webPushSubscription, error) { subscriptions := make([]*webPushSubscription, 0) for rows.Next() { - var endpoint, auth, p256dh, userID string - if err = rows.Scan(&endpoint, &auth, &p256dh, &userID); err != nil { + var id, endpoint, auth, p256dh, userID string + if err := rows.Scan(&id, &endpoint, &auth, &p256dh, &userID); err != nil { return nil, err } subscriptions = append(subscriptions, &webPushSubscription{ + ID: id, Endpoint: endpoint, Auth: auth, P256dh: p256dh, UserID: userID, }) } - - // also set warning as sent - _, err = tx.Exec(updateWarningSentQuery, fmt.Sprintf("-%.2f seconds", warningDuration.Seconds())) - if err != nil { - return nil, err - } - - if err = tx.Commit(); err != nil { - return nil, err - } - return subscriptions, nil } @@ -185,6 +219,11 @@ func (c *webPushStore) RemoveSubscriptionsByUserID(userID string) error { return err } +func (c *webPushStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error { + _, err := c.db.Exec(deleteWebPushSubscriptionByAgeQuery, time.Now().Add(-expireAfter).Unix()) + return err +} + func (c *webPushStore) Close() error { return c.db.Close() }