1
0
Fork 0
mirror of https://github.com/binwiederhier/ntfy.git synced 2025-06-15 09:03:20 +02:00

A little polishing, make upgrade banner work when not logged in

This commit is contained in:
binwiederhier 2023-01-18 13:46:40 -05:00
parent 7cff44b647
commit f945fb4cdd
15 changed files with 98 additions and 121 deletions

View file

@ -115,7 +115,6 @@ type Config struct {
EnableWeb bool
EnableSignup bool // Enable creation of accounts via API and UI
EnableLogin bool
EnablePayments bool
EnableReservations bool // Allow users with role "user" to own/reserve topics
Version string // injected by App
}

View file

@ -59,7 +59,8 @@ var (
errHTTPBadRequestPermissionInvalid = &errHTTP{40025, http.StatusBadRequest, "invalid request: incorrect permission string", ""}
errHTTPBadRequestMakesNoSenseForAdmin = &errHTTP{40026, http.StatusBadRequest, "invalid request: this makes no sense for admins", ""}
errHTTPBadRequestNotAPaidUser = &errHTTP{40027, http.StatusBadRequest, "invalid request: not a paid user", ""}
errHTTPBadRequestInvalidStripeRequest = &errHTTP{40028, http.StatusBadRequest, "invalid request: not a valid Stripe request", ""}
errHTTPBadRequestBillingRequestInvalid = &errHTTP{40028, http.StatusBadRequest, "invalid request: not a valid billing request", ""}
errHTTPBadRequestBillingSubscriptionExists = &errHTTP{40029, http.StatusBadRequest, "invalid request: billing subscription already exists", ""}
errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", ""}
errHTTPUnauthorized = &errHTTP{40101, http.StatusUnauthorized, "unauthorized", "https://ntfy.sh/docs/publish/#authentication"}
errHTTPForbidden = &errHTTP{40301, http.StatusForbidden, "forbidden", "https://ntfy.sh/docs/publish/#authentication"}

View file

@ -89,12 +89,7 @@ const (
WHERE time <= ? AND published = 0
ORDER BY time, id
`
selectMessagesExpiredQuery = `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
FROM messages
WHERE expires <= ? AND published = 1
ORDER BY time, id
`
selectMessagesExpiredQuery = `SELECT mid FROM messages WHERE expires <= ? AND published = 1`
updateMessagePublishedQuery = `UPDATE messages SET published = 1 WHERE mid = ?`
selectMessagesCountQuery = `SELECT COUNT(*) FROM messages`
selectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM messages GROUP BY topic`
@ -431,12 +426,25 @@ func (c *messageCache) MessagesDue() ([]*message, error) {
return readMessages(rows)
}
func (c *messageCache) MessagesExpired() ([]*message, error) {
// MessagesExpired returns a list of IDs for messages that have expires (should be deleted)
func (c *messageCache) MessagesExpired() ([]string, error) {
rows, err := c.db.Query(selectMessagesExpiredQuery, time.Now().Unix())
if err != nil {
return nil, err
}
return readMessages(rows)
defer rows.Close()
ids := make([]string, 0)
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
return nil, err
}
ids = append(ids, id)
}
if err := rows.Err(); err != nil {
return nil, err
}
return ids, nil
}
func (c *messageCache) MarkPublished(m *message) error {

View file

@ -270,13 +270,9 @@ func testCachePrune(t *testing.T, c *messageCache) {
require.Equal(t, 2, counts["mytopic"])
require.Equal(t, 1, counts["another_topic"])
expiredMessages, err := c.MessagesExpired()
expiredMessageIDs, err := c.MessagesExpired()
require.Nil(t, err)
ids := make([]string, 0)
for _, m := range expiredMessages {
ids = append(ids, m.ID)
}
require.Nil(t, c.DeleteMessages(ids...))
require.Nil(t, c.DeleteMessages(expiredMessageIDs...))
counts, err = c.MessageCounts()
require.Nil(t, err)

View file

@ -43,10 +43,13 @@ import (
- delete subscription when account deleted
- delete messages + reserved topics on ResetTier
- move v1/account/tiers to v1/tiers
Limits & rate limiting:
users without tier: should the stats be persisted? are they meaningful?
-> test that the visitor is based on the IP address!
login/account endpoints
when ResetStats() is run, reset messagesLimiter (and others)?
update last_seen when API is accessed
Make sure account endpoints make sense for admins
@ -54,11 +57,10 @@ import (
- flicker of upgrade banner
- JS constants
Sync:
- "mute" setting
- figure out what settings are "web" or "phone"
- sync problems with "deleteAfter=0" and "displayName="
Delete visitor when tier is changed to refresh rate limiters
Tests:
- Payment endpoints (make mocks)
- Change tier from higher to lower tier (delete reservations)
- Message rate limiting and reset tests
- test that the visitor is based on the IP address when a user has no tier
@ -104,13 +106,13 @@ var (
accountPath = "/account"
matrixPushPath = "/_matrix/push/v1/notify"
apiHealthPath = "/v1/health"
apiTiers = "/v1/tiers"
apiAccountPath = "/v1/account"
apiAccountTokenPath = "/v1/account/token"
apiAccountPasswordPath = "/v1/account/password"
apiAccountSettingsPath = "/v1/account/settings"
apiAccountSubscriptionPath = "/v1/account/subscription"
apiAccountReservationPath = "/v1/account/reservation"
apiAccountBillingTiersPath = "/v1/account/billing/tiers"
apiAccountBillingPortalPath = "/v1/account/billing/portal"
apiAccountBillingWebhookPath = "/v1/account/billing/webhook"
apiAccountBillingSubscriptionPath = "/v1/account/billing/subscription"
@ -378,20 +380,20 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
return s.ensureUser(s.withAccountSync(s.handleAccountReservationAdd))(w, r, v)
} else if r.Method == http.MethodDelete && apiAccountReservationSingleRegex.MatchString(r.URL.Path) {
return s.ensureUser(s.withAccountSync(s.handleAccountReservationDelete))(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == apiAccountBillingTiersPath {
return s.ensurePaymentsEnabled(s.handleAccountBillingTiersGet)(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingSubscriptionPath {
return s.ensurePaymentsEnabled(s.ensureUser(s.handleAccountBillingSubscriptionCreate))(w, r, v) // Account sync via incoming Stripe webhook
} else if r.Method == http.MethodGet && apiAccountBillingSubscriptionCheckoutSuccessRegex.MatchString(r.URL.Path) {
return s.ensurePaymentsEnabled(s.ensureUserManager(s.handleAccountBillingSubscriptionCreateSuccess))(w, r, v) // No user context!
} else if r.Method == http.MethodPut && r.URL.Path == apiAccountBillingSubscriptionPath {
return s.ensurePaymentsEnabled(s.ensureUser(s.handleAccountBillingSubscriptionUpdate))(w, r, v) // Account sync via incoming Stripe webhook
return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingSubscriptionUpdate))(w, r, v) // Account sync via incoming Stripe webhook
} else if r.Method == http.MethodDelete && r.URL.Path == apiAccountBillingSubscriptionPath {
return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingSubscriptionDelete))(w, r, v) // Account sync via incoming Stripe webhook
} else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingPortalPath {
return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingPortalSessionCreate))(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingWebhookPath {
return s.ensurePaymentsEnabled(s.ensureUserManager(s.handleAccountBillingWebhook))(w, r, v)
return s.ensurePaymentsEnabled(s.ensureUserManager(s.handleAccountBillingWebhook))(w, r, v) // This request comes from Stripe!
} else if r.Method == http.MethodGet && r.URL.Path == apiTiers {
return s.ensurePaymentsEnabled(s.handleBillingTiersGet)(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == matrixPushPath {
return s.handleMatrixDiscovery(w)
} else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {
@ -480,7 +482,7 @@ func (s *Server) handleWebConfig(w http.ResponseWriter, _ *http.Request, _ *visi
AppRoot: appRoot,
EnableLogin: s.config.EnableLogin,
EnableSignup: s.config.EnableSignup,
EnablePayments: s.config.EnablePayments,
EnablePayments: s.config.StripeSecretKey != "",
EnableReservations: s.config.EnableReservations,
DisallowedTopics: disallowedTopics,
}
@ -1271,18 +1273,14 @@ func (s *Server) execManager() {
// DeleteMessages message cache
log.Debug("Manager: Pruning messages")
expiredMessages, err := s.messageCache.MessagesExpired()
expiredMessageIDs, err := s.messageCache.MessagesExpired()
if err != nil {
log.Warn("Manager: Error retrieving expired messages: %s", err.Error())
} else if len(expiredMessages) > 0 {
ids := make([]string, 0)
for _, m := range expiredMessages {
ids = append(ids, m.ID)
}
if err := s.fileCache.Remove(ids...); err != nil {
} else if len(expiredMessageIDs) > 0 {
if err := s.fileCache.Remove(expiredMessageIDs...); err != nil {
log.Warn("Manager: Error deleting attachments for expired messages: %s", err.Error())
}
if err := s.messageCache.DeleteMessages(ids...); err != nil {
if err := s.messageCache.DeleteMessages(expiredMessageIDs...); err != nil {
log.Warn("Manager: Error marking attachments deleted: %s", err.Error())
}
} else {
@ -1359,6 +1357,8 @@ func (s *Server) runManager() {
}
}
// runStatsResetter runs once a day (usually midnight UTC) to reset all the visitor's message and
// email counters. The stats are used to display the counters in the web app, as well as for rate limiting.
func (s *Server) runStatsResetter() {
for {
runAt := util.NextOccurrenceUTC(s.config.VisitorStatsResetTime, time.Now())

View file

@ -33,7 +33,7 @@ func (s *Server) ensureUser(next handleFunc) handleFunc {
func (s *Server) ensurePaymentsEnabled(next handleFunc) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if !s.config.EnablePayments {
if s.config.StripeSecretKey == "" {
return errHTTPNotFound
}
return next(w, r, v)

View file

@ -25,11 +25,15 @@ const (
)
var (
errNotAPaidTier = errors.New("tier does not have Stripe price identifier")
errNotAPaidTier = errors.New("tier does not have billing price identifier")
errMultipleBillingSubscriptions = errors.New("cannot have multiple billing subscriptions")
errNoBillingSubscription = errors.New("user does not have an active billing subscription")
)
func (s *Server) handleAccountBillingTiersGet(w http.ResponseWriter, r *http.Request, v *visitor) error {
tiers, err := v.userManager.Tiers()
// handleBillingTiersGet returns all available paid tiers, and the free tier. This is to populate the upgrade dialog
// in the UI. Note that this endpoint does NOT have a user context (no v.user!).
func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
tiers, err := s.userManager.Tiers()
if err != nil {
return err
}
@ -92,7 +96,7 @@ func (s *Server) handleAccountBillingTiersGet(w http.ResponseWriter, r *http.Req
// will be updated by a subsequent webhook from Stripe, once the subscription becomes active.
func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
if v.user.Billing.StripeSubscriptionID != "" {
return errors.New("subscription already exists") //FIXME
return errHTTPBadRequestBillingSubscriptionExists
}
req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit)
if err != nil {
@ -112,7 +116,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
if err != nil {
return err
} else if stripeCustomer.Subscriptions != nil && len(stripeCustomer.Subscriptions.Data) > 0 {
return errors.New("customer cannot have more than one subscription") //FIXME
return errMultipleBillingSubscriptions
}
}
successURL := s.config.BaseURL + apiAccountBillingSubscriptionCheckoutSuccessTemplate
@ -157,15 +161,15 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
sess, err := session.Get(sessionID, nil) // FIXME how do I rate limit this?
if err != nil {
log.Warn("Stripe: %s", err)
return errHTTPBadRequestInvalidStripeRequest
return errHTTPBadRequestBillingRequestInvalid
} else if sess.Customer == nil || sess.Subscription == nil || sess.ClientReferenceID == "" {
return wrapErrHTTP(errHTTPBadRequestInvalidStripeRequest, "customer or subscription not found")
return wrapErrHTTP(errHTTPBadRequestBillingRequestInvalid, "customer or subscription not found")
}
sub, err := subscription.Get(sess.Subscription.ID, nil)
if err != nil {
return err
} else if sub.Items == nil || len(sub.Items.Data) != 1 || sub.Items.Data[0].Price == nil {
return wrapErrHTTP(errHTTPBadRequestInvalidStripeRequest, "more than one line item in existing subscription")
return wrapErrHTTP(errHTTPBadRequestBillingRequestInvalid, "more than one line item in existing subscription")
}
tier, err := s.userManager.TierByStripePrice(sub.Items.Data[0].Price.ID)
if err != nil {
@ -186,7 +190,7 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
// a user's tier accordingly. This endpoint only works if there is an existing subscription.
func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r *http.Request, v *visitor) error {
if v.user.Billing.StripeSubscriptionID == "" {
return errors.New("no existing subscription for user")
return errNoBillingSubscription
}
req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit)
if err != nil {
@ -226,9 +230,6 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
// handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user,
// and cancelling the Stripe subscription entirely
func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
if v.user.Billing.StripeCustomerID == "" {
return errHTTPBadRequestNotAPaidUser
}
if v.user.Billing.StripeSubscriptionID != "" {
params := &stripe.SubscriptionParams{
CancelAtPeriodEnd: stripe.Bool(true),
@ -269,11 +270,13 @@ func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter,
return nil
}
// handleAccountBillingWebhook handles incoming Stripe webhooks. It mainly keeps the local user database in sync
// with the Stripe view of the world. This endpoint is authorized via the Stripe webhook secret. Note that the
// visitor (v) in this endpoint is the Stripe API, so we don't have v.user available.
func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Request, _ *visitor) error {
// Note that the visitor (v) in this endpoint is the Stripe API, so we don't have v.user available
stripeSignature := r.Header.Get("Stripe-Signature")
if stripeSignature == "" {
return errHTTPBadRequestInvalidStripeRequest
return errHTTPBadRequestBillingRequestInvalid
}
body, err := util.Peek(r.Body, stripeBodyBytesLimit)
if err != nil {
@ -283,9 +286,9 @@ func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Requ
}
event, err := webhook.ConstructEvent(body.PeekedBytes, stripeSignature, s.config.StripeWebhookKey)
if err != nil {
return errHTTPBadRequestInvalidStripeRequest
return errHTTPBadRequestBillingRequestInvalid
} else if event.Data == nil || event.Data.Raw == nil {
return errHTTPBadRequestInvalidStripeRequest
return errHTTPBadRequestBillingRequestInvalid
}
log.Info("Stripe: webhook event %s received", event.Type)
switch event.Type {
@ -306,7 +309,7 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe
cancelAt := gjson.GetBytes(event, "cancel_at")
priceID := gjson.GetBytes(event, "items.data.0.price.id")
if !subscriptionID.Exists() || !status.Exists() || !currentPeriodEnd.Exists() || !cancelAt.Exists() || !priceID.Exists() {
return errHTTPBadRequestInvalidStripeRequest
return errHTTPBadRequestBillingRequestInvalid
}
log.Info("Stripe: customer %s: Updating subscription to status %s, with price %s", customerID.String(), status, priceID)
u, err := s.userManager.UserByStripeCustomer(customerID.String())
@ -327,7 +330,7 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe
func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMessage) error {
customerID := gjson.GetBytes(event, "customer")
if !customerID.Exists() {
return errHTTPBadRequestInvalidStripeRequest
return errHTTPBadRequestBillingRequestInvalid
}
log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", customerID.String())
u, err := s.userManager.UserByStripeCustomer(customerID.String())

View file

@ -745,7 +745,7 @@ func TestServer_Auth_ViaQuery(t *testing.T) {
func TestServer_StatsResetter(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.AuthDefault = user.PermissionDenyAll
c.VisitorStatsResetTime = time.Now().Add(time.Second)
c.VisitorStatsResetTime = time.Now().Add(2 * time.Second)
s := newTestServer(t, c)
go s.runStatsResetter()
@ -773,8 +773,8 @@ func TestServer_StatsResetter(t *testing.T) {
require.Nil(t, err)
require.Equal(t, int64(5), account.Stats.Messages)
// Start stats resetter
time.Sleep(1200 * time.Millisecond)
// Wait for stats resetter to run
time.Sleep(2200 * time.Millisecond)
// User stats show 0 messages now!
response = request(t, s, "GET", "/v1/account", "", nil)
@ -1325,7 +1325,7 @@ func TestServer_PublishAttachmentTooLargeBodyVisitorAttachmentTotalSizeLimit(t *
require.Equal(t, 41301, err.Code)
}
func TestServer_PublishAttachmentAndPrune(t *testing.T) {
func TestServer_PublishAttachmentAndExpire(t *testing.T) {
content := util.RandomString(5000) // > 4096
c := newTestConfig(t)

View file

@ -208,6 +208,7 @@ func (v *visitor) ResetStats() {
if v.user != nil {
v.user.Stats.Messages = 0
v.user.Stats.Emails = 0
// v.messagesLimiter = ... // FIXME
}
}