diff --git a/cmd/access.go b/cmd/access.go index 76c375bd..3b83000a 100644 --- a/cmd/access.go +++ b/cmd/access.go @@ -100,22 +100,24 @@ func changeAccess(c *cli.Context, manager *user.Manager, username string, topic if !util.Contains([]string{"", "read-write", "rw", "read-only", "read", "ro", "write-only", "write", "wo", "none", "deny"}, perms) { return errors.New("permission must be one of: read-write, read-only, write-only, or deny (or the aliases: read, ro, write, wo, none)") } - read := util.Contains([]string{"read-write", "rw", "read-only", "read", "ro"}, perms) - write := util.Contains([]string{"read-write", "rw", "write-only", "write", "wo"}, perms) + permission, err := user.ParsePermission(perms) + if err != nil { + return err + } u, err := manager.User(username) if err == user.ErrUserNotFound { return fmt.Errorf("user %s does not exist", username) } else if u.Role == user.RoleAdmin { return fmt.Errorf("user %s is an admin user, access control entries have no effect", username) } - if err := manager.AllowAccess("", username, topic, read, write); err != nil { + if err := manager.AllowAccess(username, topic, permission); err != nil { return err } - if read && write { + if permission.IsReadWrite() { fmt.Fprintf(c.App.ErrWriter, "granted read-write access to topic %s\n\n", topic) - } else if read { + } else if permission.IsRead() { fmt.Fprintf(c.App.ErrWriter, "granted read-only access to topic %s\n\n", topic) - } else if write { + } else if permission.IsWrite() { fmt.Fprintf(c.App.ErrWriter, "granted write-only access to topic %s\n\n", topic) } else { fmt.Fprintf(c.App.ErrWriter, "revoked all access to topic %s\n\n", topic) diff --git a/server/message_cache.go b/server/message_cache.go index e8564fc6..8788cf99 100644 --- a/server/message_cache.go +++ b/server/message_cache.go @@ -57,9 +57,10 @@ const ( INSERT INTO messages (mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user, encoding, published) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ` - deleteMessageQuery = `DELETE FROM messages WHERE mid = ?` - selectRowIDFromMessageID = `SELECT id FROM messages WHERE mid = ?` // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics - selectMessagesSinceTimeQuery = ` + deleteMessageQuery = `DELETE FROM messages WHERE mid = ?` + updateMessagesForTopicExpiryQuery = `UPDATE messages SET expires = ? WHERE topic = ?` + selectRowIDFromMessageID = `SELECT id FROM messages WHERE mid = ?` // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics + selectMessagesSinceTimeQuery = ` SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding FROM messages WHERE topic = ? AND time >= ? AND published = 1 @@ -96,7 +97,7 @@ const ( selectTopicsQuery = `SELECT topic FROM messages GROUP BY topic` updateAttachmentDeleted = `UPDATE messages SET attachment_deleted = 1 WHERE mid = ?` - selectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires <= ? AND attachment_deleted = 0` + selectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= ? AND attachment_deleted = 0` selectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE sender = ? AND attachment_expires >= ?` selectAttachmentsSizeByUserQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?` ) @@ -506,6 +507,20 @@ func (c *messageCache) DeleteMessages(ids ...string) error { return tx.Commit() } +func (c *messageCache) ExpireMessages(topics ...string) error { + tx, err := c.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + for _, t := range topics { + if _, err := tx.Exec(updateMessagesForTopicExpiryQuery, time.Now().Unix(), t); err != nil { + return err + } + } + return tx.Commit() +} + func (c *messageCache) AttachmentsExpired() ([]string, error) { rows, err := c.db.Query(selectAttachmentsExpiredQuery, time.Now().Unix()) if err != nil { diff --git a/server/message_cache_test.go b/server/message_cache_test.go index 064c4723..2b838f25 100644 --- a/server/message_cache_test.go +++ b/server/message_cache_test.go @@ -362,6 +362,61 @@ func testCacheAttachments(t *testing.T, c *messageCache) { require.Equal(t, int64(0), size) } +func TestSqliteCache_Attachments_Expired(t *testing.T) { + testCacheAttachmentsExpired(t, newSqliteTestCache(t)) +} + +func TestMemCache_Attachments_Expired(t *testing.T) { + testCacheAttachmentsExpired(t, newMemTestCache(t)) +} + +func testCacheAttachmentsExpired(t *testing.T, c *messageCache) { + m := newDefaultMessage("mytopic", "flower for you") + m.ID = "m1" + m.Expires = time.Now().Add(time.Hour).Unix() + require.Nil(t, c.AddMessage(m)) + + m = newDefaultMessage("mytopic", "message with attachment") + m.ID = "m2" + m.Expires = time.Now().Add(2 * time.Hour).Unix() + m.Attachment = &attachment{ + Name: "car.jpg", + Type: "image/jpeg", + Size: 10000, + Expires: time.Now().Add(2 * time.Hour).Unix(), + URL: "https://ntfy.sh/file/aCaRURL.jpg", + } + require.Nil(t, c.AddMessage(m)) + + m = newDefaultMessage("mytopic", "message with external attachment") + m.ID = "m3" + m.Expires = time.Now().Add(2 * time.Hour).Unix() + m.Attachment = &attachment{ + Name: "car.jpg", + Type: "image/jpeg", + Expires: 0, // Unknown! + URL: "https://somedomain.com/car.jpg", + } + require.Nil(t, c.AddMessage(m)) + + m = newDefaultMessage("mytopic2", "message with expired attachment") + m.ID = "m4" + m.Expires = time.Now().Add(2 * time.Hour).Unix() + m.Attachment = &attachment{ + Name: "expired-car.jpg", + Type: "image/jpeg", + Size: 20000, + Expires: time.Now().Add(-1 * time.Hour).Unix(), + URL: "https://ntfy.sh/file/aCaRURL.jpg", + } + require.Nil(t, c.AddMessage(m)) + + ids, err := c.AttachmentsExpired() + require.Nil(t, err) + require.Equal(t, 1, len(ids)) + require.Equal(t, "m4", ids[0]) +} + func TestSqliteCache_Migration_From0(t *testing.T) { filename := newSqliteTestCacheFile(t) db, err := sql.Open("sqlite3", filename) diff --git a/server/server.go b/server/server.go index 36bb9583..8d3526d3 100644 --- a/server/server.go +++ b/server/server.go @@ -40,13 +40,15 @@ import ( - v.user --> see publishSyncEventAsync() test payments: - - delete messages + reserved topics on ResetTier + - delete messages + reserved topics on ResetTier delete attachments in access.go + - reconciliation Limits & rate limiting: - users without tier: should the stats be persisted? are they meaningful? - -> test that the visitor is based on the IP address! + users without tier: should the stats be persisted? are they meaningful? -> test that the visitor is based on the IP address! login/account endpoints when ResetStats() is run, reset messagesLimiter (and others)? + Delete visitor when tier is changed to refresh rate limiters + Make sure account endpoints make sense for admins UI: @@ -55,10 +57,9 @@ import ( - JS constants Sync: - sync problems with "deleteAfter=0" and "displayName=" - Delete visitor when tier is changed to refresh rate limiters + Tests: - Payment endpoints (make mocks) - - Change tier from higher to lower tier (delete reservations) - Message rate limiting and reset tests - test that the visitor is based on the IP address when a user has no tier */ diff --git a/server/server_account.go b/server/server_account.go index 5bc36016..6bcb3233 100644 --- a/server/server_account.go +++ b/server/server_account.go @@ -119,7 +119,7 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis } func (s *Server) handleAccountDelete(w http.ResponseWriter, _ *http.Request, v *visitor) error { - if v.user.Billing.StripeCustomerID != "" { + if v.user.Billing.StripeSubscriptionID != "" { log.Info("Deleting user %s (billing customer: %s, billing subscription: %s)", v.user.Name, v.user.Billing.StripeCustomerID, v.user.Billing.StripeSubscriptionID) if v.user.Billing.StripeSubscriptionID != "" { if _, err := s.stripe.CancelSubscription(v.user.Billing.StripeSubscriptionID); err != nil { @@ -332,11 +332,7 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ return errHTTPTooManyRequestsLimitReservations } } - owner, username := v.user.Name, v.user.Name - if err := s.userManager.AllowAccess(owner, username, req.Topic, true, true); err != nil { - return err - } - if err := s.userManager.AllowAccess(owner, user.Everyone, req.Topic, everyone.IsRead(), everyone.IsWrite()); err != nil { + if err := s.userManager.ReserveAccess(v.user.Name, req.Topic, everyone); err != nil { return err } return s.writeJSON(w, newSuccessResponse()) @@ -357,10 +353,7 @@ func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.R } else if !authorized { return errHTTPUnauthorized } - if err := s.userManager.ResetAccess(v.user.Name, topic); err != nil { - return err - } - if err := s.userManager.ResetAccess(user.Everyone, topic); err != nil { + if err := s.userManager.RemoveReservations(v.user.Name, topic); err != nil { return err } return s.writeJSON(w, newSuccessResponse()) diff --git a/server/server_payments.go b/server/server_payments.go index 45ef82e8..c7ece4ef 100644 --- a/server/server_payments.go +++ b/server/server_payments.go @@ -27,6 +27,28 @@ var ( errNoBillingSubscription = errors.New("user does not have an active billing subscription") ) +// Payments in ntfy are done via Stripe. +// +// Pretty much all payments related things are in this file. The following processes +// handle payments: +// +// - Checkout: +// Creating a Stripe customer and subscription via the Checkout flow. This flow is only used if the +// ntfy user is not already a Stripe customer. This requires redirecting to the Stripe checkout page. +// It is implemented in handleAccountBillingSubscriptionCreate and the success callback +// handleAccountBillingSubscriptionCreateSuccess. +// - Update subscription: +// Switching between Stripe subscriptions (upgrade/downgrade) is handled via +// handleAccountBillingSubscriptionUpdate. This also handles proration. +// - Cancel subscription (at period end): +// Users can cancel the Stripe subscription via the web app at the end of the billing period. This +// simply updates the subscription and Stripe will cancel it. Users cannot immediately cancel the +// subscription. +// - Webhooks: +// Whenever a subscription changes (updated, deleted), Stripe sends us a request via a webhook. +// This is used to keep the local user database fields up to date. Stripe is the source of truth. +// What Stripe says is mirrored and not questioned. + // handleBillingTiersGet returns all available paid tiers, and the free tier. This is to populate the upgrade dialog // in the UI. Note that this endpoint does NOT have a user context (no v.user!). func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _ *visitor) error { @@ -37,7 +59,7 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _ freeTier := defaultVisitorLimits(s.config) response := []*apiAccountBillingTier{ { - // Free tier: no code, name or price + // This is a bit of a hack: This is the "Free" tier. It has no tier code, name or price. Limits: &apiAccountLimits{ Messages: freeTier.MessagesLimit, MessagesExpiryDuration: int64(freeTier.MessagesExpiryDuration.Seconds()), @@ -130,6 +152,9 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r return s.writeJSON(w, response) } +// handleAccountBillingSubscriptionCreateSuccess is called after the Stripe checkout session has succeeded. We use +// the session ID in the URL to retrieve the Stripe subscription and update the local database. This is the first +// and only time we can map the local username with the Stripe customer ID. func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, _ *visitor) error { // We don't have a v.user in this endpoint, only a userManager! matches := apiAccountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path) @@ -139,8 +164,7 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr sessionID := matches[1] sess, err := s.stripe.GetSession(sessionID) // FIXME How do we rate limit this? if err != nil { - log.Warn("Stripe: %s", err) - return errHTTPBadRequestBillingRequestInvalid + return err } else if sess.Customer == nil || sess.Subscription == nil || sess.ClientReferenceID == "" { return wrapErrHTTP(errHTTPBadRequestBillingRequestInvalid, "customer or subscription not found") } @@ -158,7 +182,7 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr if err != nil { return err } - if err := s.updateSubscriptionAndTier(u, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt, tier.Code); err != nil { + if err := s.updateSubscriptionAndTier(u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil { return err } http.Redirect(w, r, s.config.BaseURL+accountPath, http.StatusSeeOther) @@ -216,6 +240,8 @@ func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r return s.writeJSON(w, newSuccessResponse()) } +// handleAccountBillingPortalSessionCreate creates a session to the customer billing portal, and returns the +// redirect URL. The billing portal allows customers to change their payment methods, and cancel the subscription. func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error { if v.user.Billing.StripeCustomerID == "" { return errHTTPBadRequestNotAPaidUser @@ -250,10 +276,11 @@ func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Requ } event, err := s.stripe.ConstructWebhookEvent(body.PeekedBytes, stripeSignature, s.config.StripeWebhookKey) if err != nil { - return errHTTPBadRequestBillingRequestInvalid + return err } else if event.Data == nil || event.Data.Raw == nil { return errHTTPBadRequestBillingRequestInvalid } + log.Info("Stripe: webhook event %s received", event.Type) switch event.Type { case "customer.subscription.updated": @@ -282,7 +309,7 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe if err != nil { return err } - if err := s.updateSubscriptionAndTier(u, r.Customer, subscriptionID, r.Status, r.CurrentPeriodEnd, r.CancelAt, tier.Code); err != nil { + if err := s.updateSubscriptionAndTier(u, tier, r.Customer, subscriptionID, r.Status, r.CurrentPeriodEnd, r.CancelAt); err != nil { return err } s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified())) @@ -301,29 +328,54 @@ func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMe if err != nil { return err } - if err := s.updateSubscriptionAndTier(u, r.Customer, "", "", 0, 0, ""); err != nil { + if err := s.updateSubscriptionAndTier(u, nil, r.Customer, "", "", 0, 0); err != nil { return err } s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified())) return nil } -func (s *Server) updateSubscriptionAndTier(u *user.User, customerID, subscriptionID, status string, paidUntil, cancelAt int64, tier string) error { - u.Billing.StripeCustomerID = customerID - u.Billing.StripeSubscriptionID = subscriptionID - u.Billing.StripeSubscriptionStatus = stripe.SubscriptionStatus(status) - u.Billing.StripeSubscriptionPaidUntil = time.Unix(paidUntil, 0) - u.Billing.StripeSubscriptionCancelAt = time.Unix(cancelAt, 0) - if tier == "" { +func (s *Server) updateSubscriptionAndTier(u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error { + // Remove excess reservations (if too many for tier), and mark associated messages deleted + reservations, err := s.userManager.Reservations(u.Name) + if err != nil { + return err + } + reservationsLimit := visitorDefaultReservationsLimit + if tier != nil { + reservationsLimit = tier.ReservationsLimit + } + if int64(len(reservations)) > reservationsLimit { + topics := make([]string, 0) + for i := int64(len(reservations)) - 1; i >= reservationsLimit; i-- { + topics = append(topics, reservations[i].Topic) + } + if err := s.userManager.RemoveReservations(u.Name, topics...); err != nil { + return err + } + if err := s.messageCache.ExpireMessages(topics...); err != nil { + return err + } + } + // Change or remove tier + if tier == nil { if err := s.userManager.ResetTier(u.Name); err != nil { return err } } else { - if err := s.userManager.ChangeTier(u.Name, tier); err != nil { + if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil { return err } } - if err := s.userManager.ChangeBilling(u); err != nil { + // Update billing fields + billing := &user.Billing{ + StripeCustomerID: customerID, + StripeSubscriptionID: subscriptionID, + StripeSubscriptionStatus: stripe.SubscriptionStatus(status), + StripeSubscriptionPaidUntil: time.Unix(paidUntil, 0), + StripeSubscriptionCancelAt: time.Unix(cancelAt, 0), + } + if err := s.userManager.ChangeBilling(u.Name, billing); err != nil { return err } return nil diff --git a/server/server_payments_test.go b/server/server_payments_test.go index 2f2c60f0..634109cb 100644 --- a/server/server_payments_test.go +++ b/server/server_payments_test.go @@ -1,13 +1,17 @@ package server import ( + "encoding/json" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/stripe/stripe-go/v74" "heckel.io/ntfy/user" "heckel.io/ntfy/util" "io" + "path/filepath" + "strings" "testing" + "time" ) func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) { @@ -70,8 +74,10 @@ func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) { u, err := s.userManager.User("phil") require.Nil(t, err) - u.Billing.StripeCustomerID = "acct_123" - require.Nil(t, s.userManager.ChangeBilling(u)) + billing := &user.Billing{ + StripeCustomerID: "acct_123", + } + require.Nil(t, s.userManager.ChangeBilling(u.Name, billing)) // Create subscription response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{ @@ -109,9 +115,11 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) { u, err := s.userManager.User("phil") require.Nil(t, err) - u.Billing.StripeCustomerID = "acct_123" - u.Billing.StripeSubscriptionID = "sub_123" - require.Nil(t, s.userManager.ChangeBilling(u)) + billing := &user.Billing{ + StripeCustomerID: "acct_123", + StripeSubscriptionID: "sub_123", + } + require.Nil(t, s.userManager.ChangeBilling(u.Name, billing)) // Delete account rr := request(t, s, "DELETE", "/v1/account", "", map[string]string{ @@ -125,6 +133,127 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) { require.Equal(t, 401, rr.Code) } +func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) { + // This tests incoming webhooks from Stripe to update a subscription: + // - All Stripe columns are updated in the user table + // - When downgrading, excess reservations are deleted, including messages and attachments in + // the corresponding topics + + stripeMock := &testStripeAPI{} + defer stripeMock.AssertExpectations(t) + + c := newTestConfigWithAuthFile(t) + c.StripeSecretKey = "secret key" + c.StripeWebhookKey = "webhook key" + s := newTestServer(t, c) + s.stripe = stripeMock + + // Define how the mock should react + stripeMock. + On("ConstructWebhookEvent", mock.Anything, "stripe signature", "webhook key"). + Return(jsonToStripeEvent(t, subscriptionUpdatedEventJSON), nil) + + // Create a user with a Stripe subscription and 3 reservations + require.Nil(t, s.userManager.CreateTier(&user.Tier{ + Code: "starter", + StripePriceID: "price_1234", // ! + ReservationsLimit: 1, // ! + MessagesLimit: 100, + MessagesExpiryDuration: time.Hour, + AttachmentExpiryDuration: time.Hour, + AttachmentFileSizeLimit: 1000000, + AttachmentTotalSizeLimit: 1000000, + })) + require.Nil(t, s.userManager.CreateTier(&user.Tier{ + Code: "pro", + StripePriceID: "price_1111", // ! + ReservationsLimit: 3, // ! + MessagesLimit: 200, + MessagesExpiryDuration: time.Hour, + AttachmentExpiryDuration: time.Hour, + AttachmentFileSizeLimit: 1000000, + AttachmentTotalSizeLimit: 1000000, + })) + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) + require.Nil(t, s.userManager.ChangeTier("phil", "pro")) + require.Nil(t, s.userManager.ReserveAccess("phil", "atopic", user.PermissionDenyAll)) + require.Nil(t, s.userManager.ReserveAccess("phil", "ztopic", user.PermissionDenyAll)) + + // Add billing details + u, err := s.userManager.User("phil") + require.Nil(t, err) + + billing := &user.Billing{ + StripeCustomerID: "acct_5555", + StripeSubscriptionID: "sub_1234", + StripeSubscriptionStatus: stripe.SubscriptionStatusPastDue, + StripeSubscriptionPaidUntil: time.Unix(123, 0), + StripeSubscriptionCancelAt: time.Unix(456, 0), + } + require.Nil(t, s.userManager.ChangeBilling(u.Name, billing)) + + // Add some messages to "atopic" and "ztopic", everything in "ztopic" will be deleted + rr := request(t, s, "PUT", "/atopic", "some aaa message", map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 200, rr.Code) + + rr = request(t, s, "PUT", "/atopic", strings.Repeat("a", 5000), map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 200, rr.Code) + a2 := toMessage(t, rr.Body.String()) + require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, a2.ID)) + + rr = request(t, s, "PUT", "/ztopic", "some zzz message", map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 200, rr.Code) + + rr = request(t, s, "PUT", "/ztopic", strings.Repeat("z", 5000), map[string]string{ + "Authorization": util.BasicAuth("phil", "phil"), + }) + require.Equal(t, 200, rr.Code) + z2 := toMessage(t, rr.Body.String()) + require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID)) + + // Call the webhook: This does all the magic + rr = request(t, s, "POST", "/v1/account/billing/webhook", "dummy", map[string]string{ + "Stripe-Signature": "stripe signature", + }) + require.Equal(t, 200, rr.Code) + + // Verify that database columns were updated + u, err = s.userManager.User("phil") + require.Nil(t, err) + require.Equal(t, "starter", u.Tier.Code) // Not "pro" + require.Equal(t, "acct_5555", u.Billing.StripeCustomerID) + require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID) + require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus) // Not "past_due" + require.Equal(t, int64(1674268231), u.Billing.StripeSubscriptionPaidUntil.Unix()) // Updated + require.Equal(t, int64(1674299999), u.Billing.StripeSubscriptionCancelAt.Unix()) // Updated + + // Verify that reservations were deleted + r, err := s.userManager.Reservations("phil") + require.Nil(t, err) + require.Equal(t, 1, len(r)) // "ztopic" reservation was deleted + require.Equal(t, "atopic", r[0].Topic) + + // Verify that messages and attachments were deleted + time.Sleep(time.Second) + s.execManager() + + ms, err := s.messageCache.Messages("atopic", sinceAllMessages, false) + require.Nil(t, err) + require.Equal(t, 2, len(ms)) + require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, a2.ID)) + + ms, err = s.messageCache.Messages("ztopic", sinceAllMessages, false) + require.Nil(t, err) + require.Equal(t, 0, len(ms)) + require.NoFileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID)) +} + type testStripeAPI struct { mock.Mock } @@ -175,3 +304,34 @@ func (s *testStripeAPI) ConstructWebhookEvent(payload []byte, header string, sec } var _ stripeAPI = (*testStripeAPI)(nil) + +func jsonToStripeEvent(t *testing.T, v string) stripe.Event { + var e stripe.Event + if err := json.Unmarshal([]byte(v), &e); err != nil { + t.Fatal(err) + } + return e +} + +const subscriptionUpdatedEventJSON = ` +{ + "type": "customer.subscription.updated", + "data": { + "object": { + "id": "sub_1234", + "customer": "acct_5555", + "status": "active", + "current_period_end": 1674268231, + "cancel_at": 1674299999, + "items": { + "data": [ + { + "price": { + "id": "price_1234" + } + } + ] + } + } + } +}` diff --git a/server/server_test.go b/server/server_test.go index 6d69d6c6..4d32f409 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -640,7 +640,7 @@ func TestServer_Auth_Success_User(t *testing.T) { s := newTestServer(t, c) require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test")) - require.Nil(t, s.userManager.AllowAccess("", "ben", "mytopic", true, true)) + require.Nil(t, s.userManager.AllowAccess("ben", "mytopic", user.PermissionReadWrite)) response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{ "Authorization": util.BasicAuth("ben", "ben"), @@ -654,8 +654,8 @@ func TestServer_Auth_Success_User_MultipleTopics(t *testing.T) { s := newTestServer(t, c) require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test")) - require.Nil(t, s.userManager.AllowAccess("", "ben", "mytopic", true, true)) - require.Nil(t, s.userManager.AllowAccess("", "ben", "anothertopic", true, true)) + require.Nil(t, s.userManager.AllowAccess("ben", "mytopic", user.PermissionReadWrite)) + require.Nil(t, s.userManager.AllowAccess("ben", "anothertopic", user.PermissionReadWrite)) response := request(t, s, "GET", "/mytopic,anothertopic/auth", "", map[string]string{ "Authorization": util.BasicAuth("ben", "ben"), @@ -688,7 +688,7 @@ func TestServer_Auth_Fail_Unauthorized(t *testing.T) { s := newTestServer(t, c) require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test")) - require.Nil(t, s.userManager.AllowAccess("", "ben", "sometopic", true, true)) // Not mytopic! + require.Nil(t, s.userManager.AllowAccess("ben", "sometopic", user.PermissionReadWrite)) // Not mytopic! response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{ "Authorization": util.BasicAuth("ben", "ben"), @@ -702,8 +702,8 @@ func TestServer_Auth_Fail_CannotPublish(t *testing.T) { s := newTestServer(t, c) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, "unit-test")) - require.Nil(t, s.userManager.AllowAccess("", user.Everyone, "private", false, false)) - require.Nil(t, s.userManager.AllowAccess("", user.Everyone, "announcements", true, false)) + require.Nil(t, s.userManager.AllowAccess(user.Everyone, "private", user.PermissionDenyAll)) + require.Nil(t, s.userManager.AllowAccess(user.Everyone, "announcements", user.PermissionRead)) response := request(t, s, "PUT", "/mytopic", "test", nil) require.Equal(t, 200, response.Code) @@ -750,7 +750,7 @@ func TestServer_StatsResetter(t *testing.T) { go s.runStatsResetter() require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) - require.Nil(t, s.userManager.AllowAccess("", "phil", "mytopic", true, true)) + require.Nil(t, s.userManager.AllowAccess("phil", "mytopic", user.PermissionReadWrite)) for i := 0; i < 5; i++ { response := request(t, s, "PUT", "/mytopic", "test", map[string]string{ diff --git a/server/visitor.go b/server/visitor.go index 5fd89ffa..c752de8e 100644 --- a/server/visitor.go +++ b/server/visitor.go @@ -16,6 +16,10 @@ const ( // has to be very high to prevent e-mail abuse, but it doesn't really affect the other limits anyway, since // they are replenished faster (typically). visitorExpungeAfter = 24 * time.Hour + + // visitorDefaultReservationsLimit is the amount of topic names a user without a tier is allowed to reserve. + // This number is zero, and changing it may have unintended consequences in the web app, or otherwise + visitorDefaultReservationsLimit = int64(0) ) var ( @@ -289,7 +293,7 @@ func defaultVisitorLimits(conf *Config) *visitorLimits { MessagesLimit: replenishDurationToDailyLimit(conf.VisitorRequestLimitReplenish), MessagesExpiryDuration: conf.CacheDuration, EmailsLimit: replenishDurationToDailyLimit(conf.VisitorEmailLimitReplenish), - ReservationsLimit: 0, // No reservations for anonymous users, or users without a tier + ReservationsLimit: visitorDefaultReservationsLimit, AttachmentTotalSizeLimit: conf.VisitorAttachmentTotalSizeLimit, AttachmentFileSizeLimit: conf.AttachmentFileSizeLimit, AttachmentExpiryDuration: conf.AttachmentExpiryDuration, diff --git a/user/manager.go b/user/manager.go index 8fe2a0f7..652017e7 100644 --- a/user/manager.go +++ b/user/manager.go @@ -219,8 +219,7 @@ const ( INSERT INTO tier (code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ` - selectTierIDQuery = `SELECT id FROM tier WHERE code = ?` - selectTiersQuery = ` + selectTiersQuery = ` SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id FROM tier ` @@ -234,7 +233,7 @@ const ( FROM tier WHERE stripe_price_id = ? ` - updateUserTierQuery = `UPDATE user SET tier_id = ? WHERE user = ?` + updateUserTierQuery = `UPDATE user SET tier_id = (SELECT id FROM tier WHERE code = ?) WHERE user = ?` deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?` updateBillingQuery = ` @@ -772,26 +771,47 @@ func (a *Manager) ChangeRole(username string, role Role) error { return nil } -// ChangeTier changes a user's tier using the tier code +// ChangeTier changes a user's tier using the tier code. This function does not delete reservations, messages, +// or attachments, even if the new tier has lower limits in this regard. That has to be done elsewhere. func (a *Manager) ChangeTier(username, tier string) error { if !AllowedUsername(username) { return ErrInvalidArgument } - rows, err := a.db.Query(selectTierIDQuery, tier) + t, err := a.Tier(tier) + if err != nil { + return err + } else if err := a.checkReservationsLimit(username, t.ReservationsLimit); err != nil { + return err + } + if _, err := a.db.Exec(updateUserTierQuery, tier, username); err != nil { + return err + } + return nil +} + +// ResetTier removes the tier from the given user +func (a *Manager) ResetTier(username string) error { + if !AllowedUsername(username) && username != Everyone && username != "" { + return ErrInvalidArgument + } else if err := a.checkReservationsLimit(username, 0); err != nil { + return err + } + _, err := a.db.Exec(deleteUserTierQuery, username) + return err +} + +func (a *Manager) checkReservationsLimit(username string, reservationsLimit int64) error { + u, err := a.User(username) if err != nil { return err } - defer rows.Close() - if !rows.Next() { - return ErrInvalidArgument - } - var tierID int64 - if err := rows.Scan(&tierID); err != nil { - return err - } - rows.Close() - if _, err := a.db.Exec(updateUserTierQuery, tierID, username); err != nil { - return err + if u.Tier != nil && reservationsLimit < u.Tier.ReservationsLimit { + reservations, err := a.Reservations(username) + if err != nil { + return err + } else if int64(len(reservations)) > reservationsLimit { + return ErrTooManyReservations + } } return nil } @@ -823,20 +843,37 @@ func (a *Manager) CheckAllowAccess(username string, topic string) error { // AllowAccess adds or updates an entry in th access control list for a specific user. It controls // read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry // owner may either be a user (username), or the system (empty). -func (a *Manager) AllowAccess(owner, username string, topicPattern string, read bool, write bool) error { +func (a *Manager) AllowAccess(username string, topicPattern string, permission Permission) error { if !AllowedUsername(username) && username != Everyone { return ErrInvalidArgument - } else if owner != "" && !AllowedUsername(owner) { - return ErrInvalidArgument } else if !AllowedTopicPattern(topicPattern) { return ErrInvalidArgument } - if _, err := a.db.Exec(upsertUserAccessQuery, username, toSQLWildcard(topicPattern), read, write, owner, owner); err != nil { + owner := "" + if _, err := a.db.Exec(upsertUserAccessQuery, username, toSQLWildcard(topicPattern), permission.IsRead(), permission.IsWrite(), owner, owner); err != nil { return err } return nil } +func (a *Manager) ReserveAccess(username string, topic string, everyone Permission) error { + if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) { + return ErrInvalidArgument + } + tx, err := a.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + if _, err := tx.Exec(upsertUserAccessQuery, username, topic, true, true, username, username); err != nil { + return err + } + if _, err := tx.Exec(upsertUserAccessQuery, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username, username); err != nil { + return err + } + return tx.Commit() +} + // ResetAccess removes an access control list entry for a specific username/topic, or (if topic is // empty) for an entire user. The parameter topicPattern may include wildcards (*). func (a *Manager) ResetAccess(username string, topicPattern string) error { @@ -856,13 +893,29 @@ func (a *Manager) ResetAccess(username string, topicPattern string) error { return err } -// ResetTier removes the tier from the given user -func (a *Manager) ResetTier(username string) error { - if !AllowedUsername(username) && username != Everyone && username != "" { +func (a *Manager) RemoveReservations(username string, topics ...string) error { + if !AllowedUsername(username) || username == Everyone || len(topics) == 0 { return ErrInvalidArgument } - _, err := a.db.Exec(deleteUserTierQuery, username) - return err + for _, topic := range topics { + if !AllowedTopic(topic) { + return ErrInvalidArgument + } + } + tx, err := a.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + for _, topic := range topics { + if _, err := tx.Exec(deleteTopicAccessQuery, username, username, topic); err != nil { + return err + } + if _, err := tx.Exec(deleteTopicAccessQuery, Everyone, Everyone, topic); err != nil { + return err + } + } + return tx.Commit() } // DefaultAccess returns the default read/write access if no access control entry matches @@ -879,8 +932,8 @@ func (a *Manager) CreateTier(tier *Tier) error { } // ChangeBilling updates a user's billing fields, namely the Stripe customer ID, and subscription information -func (a *Manager) ChangeBilling(user *User) error { - if _, err := a.db.Exec(updateBillingQuery, nullString(user.Billing.StripeCustomerID), nullString(user.Billing.StripeSubscriptionID), nullString(string(user.Billing.StripeSubscriptionStatus)), nullInt64(user.Billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(user.Billing.StripeSubscriptionCancelAt.Unix()), user.Name); err != nil { +func (a *Manager) ChangeBilling(username string, billing *Billing) error { + if _, err := a.db.Exec(updateBillingQuery, nullString(billing.StripeCustomerID), nullString(billing.StripeSubscriptionID), nullString(string(billing.StripeSubscriptionStatus)), nullInt64(billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(billing.StripeSubscriptionCancelAt.Unix()), username); err != nil { return err } return nil diff --git a/user/manager_test.go b/user/manager_test.go index 021ac470..fdbe5e0e 100644 --- a/user/manager_test.go +++ b/user/manager_test.go @@ -15,13 +15,13 @@ func TestManager_FullScenario_Default_DenyAll(t *testing.T) { a := newTestManager(t, PermissionDenyAll) require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, "unit-test")) require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) - require.Nil(t, a.AllowAccess("", "ben", "mytopic", true, true)) - require.Nil(t, a.AllowAccess("", "ben", "readme", true, false)) - require.Nil(t, a.AllowAccess("", "ben", "writeme", false, true)) - require.Nil(t, a.AllowAccess("", "ben", "everyonewrite", false, false)) // How unfair! - require.Nil(t, a.AllowAccess("", Everyone, "announcements", true, false)) - require.Nil(t, a.AllowAccess("", Everyone, "everyonewrite", true, true)) - require.Nil(t, a.AllowAccess("", Everyone, "up*", false, true)) // Everyone can write to /up* + require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite)) + require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead)) + require.Nil(t, a.AllowAccess("ben", "writeme", PermissionWrite)) + require.Nil(t, a.AllowAccess("ben", "everyonewrite", PermissionDenyAll)) // How unfair! + require.Nil(t, a.AllowAccess(Everyone, "announcements", PermissionRead)) + require.Nil(t, a.AllowAccess(Everyone, "everyonewrite", PermissionReadWrite)) + require.Nil(t, a.AllowAccess(Everyone, "up*", PermissionWrite)) // Everyone can write to /up* phil, err := a.Authenticate("phil", "phil") require.Nil(t, err) @@ -130,12 +130,12 @@ func TestManager_UserManagement(t *testing.T) { a := newTestManager(t, PermissionDenyAll) require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, "unit-test")) require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) - require.Nil(t, a.AllowAccess("", "ben", "mytopic", true, true)) - require.Nil(t, a.AllowAccess("", "ben", "readme", true, false)) - require.Nil(t, a.AllowAccess("", "ben", "writeme", false, true)) - require.Nil(t, a.AllowAccess("", "ben", "everyonewrite", false, false)) // How unfair! - require.Nil(t, a.AllowAccess("", Everyone, "announcements", true, false)) - require.Nil(t, a.AllowAccess("", Everyone, "everyonewrite", true, true)) + require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite)) + require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead)) + require.Nil(t, a.AllowAccess("ben", "writeme", PermissionWrite)) + require.Nil(t, a.AllowAccess("ben", "everyonewrite", PermissionDenyAll)) // How unfair! + require.Nil(t, a.AllowAccess(Everyone, "announcements", PermissionRead)) + require.Nil(t, a.AllowAccess(Everyone, "everyonewrite", PermissionReadWrite)) // Query user details phil, err := a.User("phil") @@ -177,9 +177,9 @@ func TestManager_UserManagement(t *testing.T) { }, everyoneGrants) // Ben: Before revoking - require.Nil(t, a.AllowAccess("", "ben", "mytopic", true, true)) // Overwrite! - require.Nil(t, a.AllowAccess("", "ben", "readme", true, false)) - require.Nil(t, a.AllowAccess("", "ben", "writeme", false, true)) + require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite)) // Overwrite! + require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead)) + require.Nil(t, a.AllowAccess("ben", "writeme", PermissionWrite)) require.Nil(t, a.Authorize(ben, "mytopic", PermissionRead)) require.Nil(t, a.Authorize(ben, "mytopic", PermissionWrite)) require.Nil(t, a.Authorize(ben, "readme", PermissionRead)) @@ -234,8 +234,8 @@ func TestManager_ChangePassword(t *testing.T) { func TestManager_ChangeRole(t *testing.T) { a := newTestManager(t, PermissionDenyAll) require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) - require.Nil(t, a.AllowAccess("", "ben", "mytopic", true, true)) - require.Nil(t, a.AllowAccess("", "ben", "readme", true, false)) + require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite)) + require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead)) ben, err := a.User("ben") require.Nil(t, err) @@ -256,6 +256,28 @@ func TestManager_ChangeRole(t *testing.T) { require.Equal(t, 0, len(benGrants)) } +func TestManager_Reservations(t *testing.T) { + a := newTestManager(t, PermissionDenyAll) + require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) + require.Nil(t, a.ReserveAccess("ben", "ztopic", PermissionDenyAll)) + require.Nil(t, a.ReserveAccess("ben", "readme", PermissionRead)) + require.Nil(t, a.AllowAccess("ben", "something-else", PermissionRead)) + + reservations, err := a.Reservations("ben") + require.Nil(t, err) + require.Equal(t, 2, len(reservations)) + require.Equal(t, Reservation{ + Topic: "readme", + Owner: PermissionReadWrite, + Everyone: PermissionRead, + }, reservations[0]) + require.Equal(t, Reservation{ + Topic: "ztopic", + Owner: PermissionReadWrite, + Everyone: PermissionDenyAll, + }, reservations[1]) +} + func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) { a := newTestManager(t, PermissionDenyAll) require.Nil(t, a.CreateTier(&Tier{ @@ -272,8 +294,7 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) { })) require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) require.Nil(t, a.ChangeTier("ben", "pro")) - require.Nil(t, a.AllowAccess("ben", "ben", "mytopic", true, true)) - require.Nil(t, a.AllowAccess("ben", Everyone, "mytopic", false, false)) + require.Nil(t, a.ReserveAccess("ben", "mytopic", PermissionDenyAll)) ben, err := a.User("ben") require.Nil(t, err) @@ -298,6 +319,13 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) { require.Equal(t, 1, len(everyoneGrants)) require.Equal(t, PermissionDenyAll, everyoneGrants[0].Allow) + benReservations, err := a.Reservations("ben") + require.Nil(t, err) + require.Equal(t, 1, len(benReservations)) + require.Equal(t, "mytopic", benReservations[0].Topic) + require.Equal(t, PermissionReadWrite, benReservations[0].Owner) + require.Equal(t, PermissionDenyAll, benReservations[0].Everyone) + // Switch to admin, this should remove all grants and owned ACL entries require.Nil(t, a.ChangeRole("ben", RoleAdmin)) diff --git a/user/types.go b/user/types.go index 2aca5652..5e95ad56 100644 --- a/user/types.go +++ b/user/types.go @@ -221,19 +221,10 @@ func AllowedTier(tier string) bool { // Error constants used by the package var ( - ErrUnauthenticated = errors.New("unauthenticated") - ErrUnauthorized = errors.New("unauthorized") - ErrInvalidArgument = errors.New("invalid argument") - ErrUserNotFound = errors.New("user not found") - ErrTierNotFound = errors.New("tier not found") -) - -// BillingStatus represents the status of a Stripe subscription -type BillingStatus string - -// BillingStatus values, subset of https://stripe.com/docs/billing/subscriptions/overview -const ( - BillingStatusIncomplete = BillingStatus("incomplete") - BillingStatusActive = BillingStatus("active") - BillingStatusPastDue = BillingStatus("past_due") + ErrUnauthenticated = errors.New("unauthenticated") + ErrUnauthorized = errors.New("unauthorized") + ErrInvalidArgument = errors.New("invalid argument") + ErrUserNotFound = errors.New("user not found") + ErrTierNotFound = errors.New("tier not found") + ErrTooManyReservations = errors.New("new tier has lower reservation limit") ) diff --git a/web/public/static/langs/en.json b/web/public/static/langs/en.json index a96e1072..cac86d87 100644 --- a/web/public/static/langs/en.json +++ b/web/public/static/langs/en.json @@ -201,15 +201,19 @@ "account_delete_dialog_label": "Type '{{username}}' to delete account", "account_delete_dialog_button_cancel": "Cancel", "account_delete_dialog_button_submit": "Permanently delete account", + "account_delete_dialog_billing_warning": "Deleting your account also cancels your billing subscription immediately. You will not have access to the billing dashboard anymore.", "account_upgrade_dialog_title": "Change account tier", "account_upgrade_dialog_cancel_warning": "This will cancel your subscription, and downgrade your account on {{date}}. On that date, topic reservations as well as messages cached on the server will be deleted.", "account_upgrade_dialog_proration_info": "Proration: When switching between paid plans, the price difference will be charged or refunded in the next invoice. You will not receive another invoice until the end of the next billing period.", + "account_upgrade_dialog_reservations_warning_one": "The selected tier allows fewer reserved topics than your current tier. Before changing your tier, please delete at least one reservation. You can remove reservations in the Settings.", + "account_upgrade_dialog_reservations_warning_other": "The selected tier allows fewer reserved topics than your current tier. Before changing your tier, please delete at least {{count}} reservations. You can remove reservations in the Settings.", "account_upgrade_dialog_tier_features_reservations": "{{reservations}} reserved topics", "account_upgrade_dialog_tier_features_messages": "{{messages}} daily messages", "account_upgrade_dialog_tier_features_emails": "{{emails}} daily emails", "account_upgrade_dialog_tier_features_attachment_file_size": "{{filesize}} per file", "account_upgrade_dialog_tier_features_attachment_total_size": "{{totalsize}} total storage", "account_upgrade_dialog_tier_selected_label": "Selected", + "account_upgrade_dialog_tier_current_label": "Current", "account_upgrade_dialog_button_cancel": "Cancel", "account_upgrade_dialog_button_redirect_signup": "Sign up now", "account_upgrade_dialog_button_pay_now": "Pay now and subscribe", diff --git a/web/src/components/Account.js b/web/src/components/Account.js index f451e2be..452868b7 100644 --- a/web/src/components/Account.js +++ b/web/src/components/Account.js @@ -264,7 +264,6 @@ const AccountType = () => { const Stats = () => { const { t } = useTranslation(); const { account } = useContext(AccountContext); - const [upgradeDialogOpen, setUpgradeDialogOpen] = useState(false); if (!account) { return <>; @@ -435,6 +434,7 @@ const DeleteAccount = () => { const DeleteAccountDialog = (props) => { const { t } = useTranslation(); + const { account } = useContext(AccountContext); const [username, setUsername] = useState(""); const fullScreen = useMediaQuery(theme.breakpoints.down('sm')); const buttonEnabled = username === session.username(); @@ -456,6 +456,9 @@ const DeleteAccountDialog = (props) => { fullWidth variant="standard" /> + {account?.billing?.subscription && + {t("account_delete_dialog_billing_warning")} + } diff --git a/web/src/components/Preferences.js b/web/src/components/Preferences.js index 8e2c8bf9..23224fcd 100644 --- a/web/src/components/Preferences.js +++ b/web/src/components/Preferences.js @@ -3,7 +3,7 @@ import {useContext, useEffect, useState} from 'react'; import { Alert, CardActions, - CardContent, + CardContent, Chip, FormControl, Select, Stack, @@ -20,6 +20,7 @@ import prefs from "../app/Prefs"; import {Paragraph} from "./styles"; import EditIcon from '@mui/icons-material/Edit'; import CloseIcon from "@mui/icons-material/Close"; +import WarningIcon from '@mui/icons-material/Warning'; import IconButton from "@mui/material/IconButton"; import PlayArrowIcon from '@mui/icons-material/PlayArrow'; import Container from "@mui/material/Container"; @@ -41,10 +42,12 @@ import routes from "./routes"; import accountApi, {UnauthorizedError} from "../app/AccountApi"; import {Pref, PrefGroup} from "./Pref"; import LockIcon from "@mui/icons-material/Lock"; -import {Public, PublicOff} from "@mui/icons-material"; +import {Check, Info, Public, PublicOff} from "@mui/icons-material"; import DialogContentText from "@mui/material/DialogContentText"; import ReserveTopicSelect from "./ReserveTopicSelect"; import {AccountContext} from "./App"; +import {useOutletContext} from "react-router-dom"; +import subscriptionManager from "../app/SubscriptionManager"; const Preferences = () => { return ( @@ -543,6 +546,12 @@ const ReservationsTable = (props) => { const [dialogKey, setDialogKey] = useState(0); const [dialogOpen, setDialogOpen] = useState(false); const [dialogReservation, setDialogReservation] = useState(null); + const { subscriptions } = useOutletContext(); + const localSubscriptions = Object.assign( + ...subscriptions + .filter(s => s.baseUrl === config.base_url) + .map(s => ({[s.topic]: s})) + ); const handleEditClick = (reservation) => { setDialogKey(prev => prev+1); @@ -592,7 +601,9 @@ const ReservationsTable = (props) => { key={reservation.topic} sx={{'&:last-child td, &:last-child th': {border: 0}}} > - {reservation.topic} + + {reservation.topic} + {reservation.everyone === "read-write" && <> @@ -620,6 +631,9 @@ const ReservationsTable = (props) => { } + {!localSubscriptions[reservation.topic] && + } label="Not subscribed" color="primary" variant="outlined"/> + } handleEditClick(reservation)} aria-label={t("prefs_reservations_edit_button")}> diff --git a/web/src/components/UpgradeDialog.js b/web/src/components/UpgradeDialog.js index 798f21b2..5e2a068a 100644 --- a/web/src/components/UpgradeDialog.js +++ b/web/src/components/UpgradeDialog.js @@ -21,13 +21,14 @@ import {Check} from "@mui/icons-material"; import ListItemIcon from "@mui/material/ListItemIcon"; import ListItemText from "@mui/material/ListItemText"; import Box from "@mui/material/Box"; +import {NavLink} from "react-router-dom"; const UpgradeDialog = (props) => { const { t } = useTranslation(); const { account } = useContext(AccountContext); // May be undefined! const fullScreen = useMediaQuery(theme.breakpoints.down('sm')); const [tiers, setTiers] = useState(null); - const [newTier, setNewTier] = useState(account?.tier?.code); // May be undefined + const [newTierCode, setNewTierCode] = useState(account?.tier?.code); // May be undefined const [loading, setLoading] = useState(false); const [errorText, setErrorText] = useState(""); @@ -41,47 +42,56 @@ const UpgradeDialog = (props) => { return <>; } - const currentTier = account?.tier?.code; // May be undefined - let action, submitButtonLabel, submitButtonEnabled; + const tiersMap = Object.assign(...tiers.map(tier => ({[tier.code]: tier}))); + const newTier = tiersMap[newTierCode]; // May be undefined + const currentTier = account?.tier; // May be undefined + const currentTierCode = currentTier?.code; // May be undefined + + // Figure out buttons, labels and the submit action + let submitAction, submitButtonLabel, banner; if (!account) { submitButtonLabel = t("account_upgrade_dialog_button_redirect_signup"); - submitButtonEnabled = true; - action = Action.REDIRECT_SIGNUP; - } else if (currentTier === newTier) { + submitAction = Action.REDIRECT_SIGNUP; + banner = null; + } else if (currentTierCode === newTierCode) { submitButtonLabel = t("account_upgrade_dialog_button_update_subscription"); - submitButtonEnabled = false; - action = null; - } else if (!currentTier) { + submitAction = null; + banner = (currentTierCode) ? Banner.PRORATION_INFO : null; + } else if (!currentTierCode) { submitButtonLabel = t("account_upgrade_dialog_button_pay_now"); - submitButtonEnabled = true; - action = Action.CREATE_SUBSCRIPTION; - } else if (!newTier) { + submitAction = Action.CREATE_SUBSCRIPTION; + banner = null; + } else if (!newTierCode) { submitButtonLabel = t("account_upgrade_dialog_button_cancel_subscription"); - submitButtonEnabled = true; - action = Action.CANCEL_SUBSCRIPTION; + submitAction = Action.CANCEL_SUBSCRIPTION; + banner = Banner.CANCEL_WARNING; } else { submitButtonLabel = t("account_upgrade_dialog_button_update_subscription"); - submitButtonEnabled = true; - action = Action.UPDATE_SUBSCRIPTION; + submitAction = Action.UPDATE_SUBSCRIPTION; + banner = Banner.PRORATION_INFO; } + // Exceptional conditions if (loading) { - submitButtonEnabled = false; + submitAction = null; + } else if (newTier?.code && account?.reservations.length > newTier?.limits.reservations) { + submitAction = null; + banner = Banner.RESERVATIONS_WARNING; } const handleSubmit = async () => { - if (action === Action.REDIRECT_SIGNUP) { + if (submitAction === Action.REDIRECT_SIGNUP) { window.location.href = routes.signup; return; } try { setLoading(true); - if (action === Action.CREATE_SUBSCRIPTION) { - const response = await accountApi.createBillingSubscription(newTier); + if (submitAction === Action.CREATE_SUBSCRIPTION) { + const response = await accountApi.createBillingSubscription(newTierCode); window.location.href = response.redirect_url; - } else if (action === Action.UPDATE_SUBSCRIPTION) { - await accountApi.updateBillingSubscription(newTier); - } else if (action === Action.CANCEL_SUBSCRIPTION) { + } else if (submitAction === Action.UPDATE_SUBSCRIPTION) { + await accountApi.updateBillingSubscription(newTierCode); + } else if (submitAction === Action.CANCEL_SUBSCRIPTION) { await accountApi.deleteBillingSubscription(); } props.onCancel(); @@ -116,27 +126,39 @@ const UpgradeDialog = (props) => { setNewTier(tier.code)} // tier.code may be undefined! + current={currentTierCode === tier.code} // tier.code or currentTierCode may be undefined! + selected={newTierCode === tier.code} // tier.code may be undefined! + onClick={() => setNewTierCode(tier.code)} // tier.code may be undefined! /> )} - {action === Action.CANCEL_SUBSCRIPTION && + {banner === Banner.CANCEL_WARNING && } - {currentTier && (!action || action === Action.UPDATE_SUBSCRIPTION) && + {banner === Banner.PRORATION_INFO && } + {banner === Banner.RESERVATIONS_WARNING && + + , + }} + /> + + } - + ); @@ -144,8 +166,19 @@ const UpgradeDialog = (props) => { const TierCard = (props) => { const { t } = useTranslation(); - const cardStyle = (props.selected) ? { background: "#eee", border: "2px solid #338574" } : { border: "2px solid transparent" }; const tier = props.tier; + let cardStyle, labelStyle, labelText; + if (props.selected) { + cardStyle = { background: "#eee", border: "2px solid #338574" }; + labelStyle = { background: "#338574", color: "white" }; + labelText = t("account_upgrade_dialog_tier_selected_label"); + } else if (props.current) { + cardStyle = { border: "2px solid #eee" }; + labelStyle = { background: "#eee", color: "black" }; + labelText = t("account_upgrade_dialog_tier_current_label"); + } else { + cardStyle = { border: "2px solid transparent" }; + } return ( { - {props.selected && + {labelStyle &&
{t("account_upgrade_dialog_tier_selected_label")}
+ ...labelStyle + }}>{labelText} } {tier.name || t("account_usage_tier_free")} @@ -217,10 +249,17 @@ const FeatureItem = (props) => { }; const Action = { - REDIRECT_SIGNUP: 0, - CREATE_SUBSCRIPTION: 1, - UPDATE_SUBSCRIPTION: 2, - CANCEL_SUBSCRIPTION: 3 + REDIRECT_SIGNUP: 1, + CREATE_SUBSCRIPTION: 2, + UPDATE_SUBSCRIPTION: 3, + CANCEL_SUBSCRIPTION: 4 }; +const Banner = { + CANCEL_WARNING: 1, + PRORATION_INFO: 2, + RESERVATIONS_WARNING: 3 +}; + + export default UpgradeDialog;