mirror of
https://github.com/binwiederhier/ntfy.git
synced 2024-11-24 20:29:21 +01:00
Payments webhook test, delete attachments/messages when reservations are removed,
This commit is contained in:
parent
45b97c7054
commit
31a3bb7cd6
16 changed files with 571 additions and 157 deletions
|
@ -100,22 +100,24 @@ func changeAccess(c *cli.Context, manager *user.Manager, username string, topic
|
|||
if !util.Contains([]string{"", "read-write", "rw", "read-only", "read", "ro", "write-only", "write", "wo", "none", "deny"}, perms) {
|
||||
return errors.New("permission must be one of: read-write, read-only, write-only, or deny (or the aliases: read, ro, write, wo, none)")
|
||||
}
|
||||
read := util.Contains([]string{"read-write", "rw", "read-only", "read", "ro"}, perms)
|
||||
write := util.Contains([]string{"read-write", "rw", "write-only", "write", "wo"}, perms)
|
||||
permission, err := user.ParsePermission(perms)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u, err := manager.User(username)
|
||||
if err == user.ErrUserNotFound {
|
||||
return fmt.Errorf("user %s does not exist", username)
|
||||
} else if u.Role == user.RoleAdmin {
|
||||
return fmt.Errorf("user %s is an admin user, access control entries have no effect", username)
|
||||
}
|
||||
if err := manager.AllowAccess("", username, topic, read, write); err != nil {
|
||||
if err := manager.AllowAccess(username, topic, permission); err != nil {
|
||||
return err
|
||||
}
|
||||
if read && write {
|
||||
if permission.IsReadWrite() {
|
||||
fmt.Fprintf(c.App.ErrWriter, "granted read-write access to topic %s\n\n", topic)
|
||||
} else if read {
|
||||
} else if permission.IsRead() {
|
||||
fmt.Fprintf(c.App.ErrWriter, "granted read-only access to topic %s\n\n", topic)
|
||||
} else if write {
|
||||
} else if permission.IsWrite() {
|
||||
fmt.Fprintf(c.App.ErrWriter, "granted write-only access to topic %s\n\n", topic)
|
||||
} else {
|
||||
fmt.Fprintf(c.App.ErrWriter, "revoked all access to topic %s\n\n", topic)
|
||||
|
|
|
@ -58,6 +58,7 @@ const (
|
|||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
deleteMessageQuery = `DELETE FROM messages WHERE mid = ?`
|
||||
updateMessagesForTopicExpiryQuery = `UPDATE messages SET expires = ? WHERE topic = ?`
|
||||
selectRowIDFromMessageID = `SELECT id FROM messages WHERE mid = ?` // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics
|
||||
selectMessagesSinceTimeQuery = `
|
||||
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
|
||||
|
@ -96,7 +97,7 @@ const (
|
|||
selectTopicsQuery = `SELECT topic FROM messages GROUP BY topic`
|
||||
|
||||
updateAttachmentDeleted = `UPDATE messages SET attachment_deleted = 1 WHERE mid = ?`
|
||||
selectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires <= ? AND attachment_deleted = 0`
|
||||
selectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= ? AND attachment_deleted = 0`
|
||||
selectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE sender = ? AND attachment_expires >= ?`
|
||||
selectAttachmentsSizeByUserQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?`
|
||||
)
|
||||
|
@ -506,6 +507,20 @@ func (c *messageCache) DeleteMessages(ids ...string) error {
|
|||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (c *messageCache) ExpireMessages(topics ...string) error {
|
||||
tx, err := c.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
for _, t := range topics {
|
||||
if _, err := tx.Exec(updateMessagesForTopicExpiryQuery, time.Now().Unix(), t); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (c *messageCache) AttachmentsExpired() ([]string, error) {
|
||||
rows, err := c.db.Query(selectAttachmentsExpiredQuery, time.Now().Unix())
|
||||
if err != nil {
|
||||
|
|
|
@ -362,6 +362,61 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
|
|||
require.Equal(t, int64(0), size)
|
||||
}
|
||||
|
||||
func TestSqliteCache_Attachments_Expired(t *testing.T) {
|
||||
testCacheAttachmentsExpired(t, newSqliteTestCache(t))
|
||||
}
|
||||
|
||||
func TestMemCache_Attachments_Expired(t *testing.T) {
|
||||
testCacheAttachmentsExpired(t, newMemTestCache(t))
|
||||
}
|
||||
|
||||
func testCacheAttachmentsExpired(t *testing.T, c *messageCache) {
|
||||
m := newDefaultMessage("mytopic", "flower for you")
|
||||
m.ID = "m1"
|
||||
m.Expires = time.Now().Add(time.Hour).Unix()
|
||||
require.Nil(t, c.AddMessage(m))
|
||||
|
||||
m = newDefaultMessage("mytopic", "message with attachment")
|
||||
m.ID = "m2"
|
||||
m.Expires = time.Now().Add(2 * time.Hour).Unix()
|
||||
m.Attachment = &attachment{
|
||||
Name: "car.jpg",
|
||||
Type: "image/jpeg",
|
||||
Size: 10000,
|
||||
Expires: time.Now().Add(2 * time.Hour).Unix(),
|
||||
URL: "https://ntfy.sh/file/aCaRURL.jpg",
|
||||
}
|
||||
require.Nil(t, c.AddMessage(m))
|
||||
|
||||
m = newDefaultMessage("mytopic", "message with external attachment")
|
||||
m.ID = "m3"
|
||||
m.Expires = time.Now().Add(2 * time.Hour).Unix()
|
||||
m.Attachment = &attachment{
|
||||
Name: "car.jpg",
|
||||
Type: "image/jpeg",
|
||||
Expires: 0, // Unknown!
|
||||
URL: "https://somedomain.com/car.jpg",
|
||||
}
|
||||
require.Nil(t, c.AddMessage(m))
|
||||
|
||||
m = newDefaultMessage("mytopic2", "message with expired attachment")
|
||||
m.ID = "m4"
|
||||
m.Expires = time.Now().Add(2 * time.Hour).Unix()
|
||||
m.Attachment = &attachment{
|
||||
Name: "expired-car.jpg",
|
||||
Type: "image/jpeg",
|
||||
Size: 20000,
|
||||
Expires: time.Now().Add(-1 * time.Hour).Unix(),
|
||||
URL: "https://ntfy.sh/file/aCaRURL.jpg",
|
||||
}
|
||||
require.Nil(t, c.AddMessage(m))
|
||||
|
||||
ids, err := c.AttachmentsExpired()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(ids))
|
||||
require.Equal(t, "m4", ids[0])
|
||||
}
|
||||
|
||||
func TestSqliteCache_Migration_From0(t *testing.T) {
|
||||
filename := newSqliteTestCacheFile(t)
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
|
|
|
@ -40,13 +40,15 @@ import (
|
|||
- v.user --> see publishSyncEventAsync() test
|
||||
|
||||
payments:
|
||||
- delete messages + reserved topics on ResetTier
|
||||
- delete messages + reserved topics on ResetTier delete attachments in access.go
|
||||
- reconciliation
|
||||
|
||||
Limits & rate limiting:
|
||||
users without tier: should the stats be persisted? are they meaningful?
|
||||
-> test that the visitor is based on the IP address!
|
||||
users without tier: should the stats be persisted? are they meaningful? -> test that the visitor is based on the IP address!
|
||||
login/account endpoints
|
||||
when ResetStats() is run, reset messagesLimiter (and others)?
|
||||
Delete visitor when tier is changed to refresh rate limiters
|
||||
|
||||
Make sure account endpoints make sense for admins
|
||||
|
||||
UI:
|
||||
|
@ -55,10 +57,9 @@ import (
|
|||
- JS constants
|
||||
Sync:
|
||||
- sync problems with "deleteAfter=0" and "displayName="
|
||||
Delete visitor when tier is changed to refresh rate limiters
|
||||
|
||||
Tests:
|
||||
- Payment endpoints (make mocks)
|
||||
- Change tier from higher to lower tier (delete reservations)
|
||||
- Message rate limiting and reset tests
|
||||
- test that the visitor is based on the IP address when a user has no tier
|
||||
*/
|
||||
|
|
|
@ -119,7 +119,7 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis
|
|||
}
|
||||
|
||||
func (s *Server) handleAccountDelete(w http.ResponseWriter, _ *http.Request, v *visitor) error {
|
||||
if v.user.Billing.StripeCustomerID != "" {
|
||||
if v.user.Billing.StripeSubscriptionID != "" {
|
||||
log.Info("Deleting user %s (billing customer: %s, billing subscription: %s)", v.user.Name, v.user.Billing.StripeCustomerID, v.user.Billing.StripeSubscriptionID)
|
||||
if v.user.Billing.StripeSubscriptionID != "" {
|
||||
if _, err := s.stripe.CancelSubscription(v.user.Billing.StripeSubscriptionID); err != nil {
|
||||
|
@ -332,11 +332,7 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
|
|||
return errHTTPTooManyRequestsLimitReservations
|
||||
}
|
||||
}
|
||||
owner, username := v.user.Name, v.user.Name
|
||||
if err := s.userManager.AllowAccess(owner, username, req.Topic, true, true); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.userManager.AllowAccess(owner, user.Everyone, req.Topic, everyone.IsRead(), everyone.IsWrite()); err != nil {
|
||||
if err := s.userManager.ReserveAccess(v.user.Name, req.Topic, everyone); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.writeJSON(w, newSuccessResponse())
|
||||
|
@ -357,10 +353,7 @@ func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.R
|
|||
} else if !authorized {
|
||||
return errHTTPUnauthorized
|
||||
}
|
||||
if err := s.userManager.ResetAccess(v.user.Name, topic); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.userManager.ResetAccess(user.Everyone, topic); err != nil {
|
||||
if err := s.userManager.RemoveReservations(v.user.Name, topic); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.writeJSON(w, newSuccessResponse())
|
||||
|
|
|
@ -27,6 +27,28 @@ var (
|
|||
errNoBillingSubscription = errors.New("user does not have an active billing subscription")
|
||||
)
|
||||
|
||||
// Payments in ntfy are done via Stripe.
|
||||
//
|
||||
// Pretty much all payments related things are in this file. The following processes
|
||||
// handle payments:
|
||||
//
|
||||
// - Checkout:
|
||||
// Creating a Stripe customer and subscription via the Checkout flow. This flow is only used if the
|
||||
// ntfy user is not already a Stripe customer. This requires redirecting to the Stripe checkout page.
|
||||
// It is implemented in handleAccountBillingSubscriptionCreate and the success callback
|
||||
// handleAccountBillingSubscriptionCreateSuccess.
|
||||
// - Update subscription:
|
||||
// Switching between Stripe subscriptions (upgrade/downgrade) is handled via
|
||||
// handleAccountBillingSubscriptionUpdate. This also handles proration.
|
||||
// - Cancel subscription (at period end):
|
||||
// Users can cancel the Stripe subscription via the web app at the end of the billing period. This
|
||||
// simply updates the subscription and Stripe will cancel it. Users cannot immediately cancel the
|
||||
// subscription.
|
||||
// - Webhooks:
|
||||
// Whenever a subscription changes (updated, deleted), Stripe sends us a request via a webhook.
|
||||
// This is used to keep the local user database fields up to date. Stripe is the source of truth.
|
||||
// What Stripe says is mirrored and not questioned.
|
||||
|
||||
// handleBillingTiersGet returns all available paid tiers, and the free tier. This is to populate the upgrade dialog
|
||||
// in the UI. Note that this endpoint does NOT have a user context (no v.user!).
|
||||
func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
|
||||
|
@ -37,7 +59,7 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
|
|||
freeTier := defaultVisitorLimits(s.config)
|
||||
response := []*apiAccountBillingTier{
|
||||
{
|
||||
// Free tier: no code, name or price
|
||||
// This is a bit of a hack: This is the "Free" tier. It has no tier code, name or price.
|
||||
Limits: &apiAccountLimits{
|
||||
Messages: freeTier.MessagesLimit,
|
||||
MessagesExpiryDuration: int64(freeTier.MessagesExpiryDuration.Seconds()),
|
||||
|
@ -130,6 +152,9 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
|
|||
return s.writeJSON(w, response)
|
||||
}
|
||||
|
||||
// handleAccountBillingSubscriptionCreateSuccess is called after the Stripe checkout session has succeeded. We use
|
||||
// the session ID in the URL to retrieve the Stripe subscription and update the local database. This is the first
|
||||
// and only time we can map the local username with the Stripe customer ID.
|
||||
func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, _ *visitor) error {
|
||||
// We don't have a v.user in this endpoint, only a userManager!
|
||||
matches := apiAccountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path)
|
||||
|
@ -139,8 +164,7 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
|
|||
sessionID := matches[1]
|
||||
sess, err := s.stripe.GetSession(sessionID) // FIXME How do we rate limit this?
|
||||
if err != nil {
|
||||
log.Warn("Stripe: %s", err)
|
||||
return errHTTPBadRequestBillingRequestInvalid
|
||||
return err
|
||||
} else if sess.Customer == nil || sess.Subscription == nil || sess.ClientReferenceID == "" {
|
||||
return wrapErrHTTP(errHTTPBadRequestBillingRequestInvalid, "customer or subscription not found")
|
||||
}
|
||||
|
@ -158,7 +182,7 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.updateSubscriptionAndTier(u, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt, tier.Code); err != nil {
|
||||
if err := s.updateSubscriptionAndTier(u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil {
|
||||
return err
|
||||
}
|
||||
http.Redirect(w, r, s.config.BaseURL+accountPath, http.StatusSeeOther)
|
||||
|
@ -216,6 +240,8 @@ func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r
|
|||
return s.writeJSON(w, newSuccessResponse())
|
||||
}
|
||||
|
||||
// handleAccountBillingPortalSessionCreate creates a session to the customer billing portal, and returns the
|
||||
// redirect URL. The billing portal allows customers to change their payment methods, and cancel the subscription.
|
||||
func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
if v.user.Billing.StripeCustomerID == "" {
|
||||
return errHTTPBadRequestNotAPaidUser
|
||||
|
@ -250,10 +276,11 @@ func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Requ
|
|||
}
|
||||
event, err := s.stripe.ConstructWebhookEvent(body.PeekedBytes, stripeSignature, s.config.StripeWebhookKey)
|
||||
if err != nil {
|
||||
return errHTTPBadRequestBillingRequestInvalid
|
||||
return err
|
||||
} else if event.Data == nil || event.Data.Raw == nil {
|
||||
return errHTTPBadRequestBillingRequestInvalid
|
||||
}
|
||||
|
||||
log.Info("Stripe: webhook event %s received", event.Type)
|
||||
switch event.Type {
|
||||
case "customer.subscription.updated":
|
||||
|
@ -282,7 +309,7 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.updateSubscriptionAndTier(u, r.Customer, subscriptionID, r.Status, r.CurrentPeriodEnd, r.CancelAt, tier.Code); err != nil {
|
||||
if err := s.updateSubscriptionAndTier(u, tier, r.Customer, subscriptionID, r.Status, r.CurrentPeriodEnd, r.CancelAt); err != nil {
|
||||
return err
|
||||
}
|
||||
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
|
||||
|
@ -301,29 +328,54 @@ func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMe
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.updateSubscriptionAndTier(u, r.Customer, "", "", 0, 0, ""); err != nil {
|
||||
if err := s.updateSubscriptionAndTier(u, nil, r.Customer, "", "", 0, 0); err != nil {
|
||||
return err
|
||||
}
|
||||
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) updateSubscriptionAndTier(u *user.User, customerID, subscriptionID, status string, paidUntil, cancelAt int64, tier string) error {
|
||||
u.Billing.StripeCustomerID = customerID
|
||||
u.Billing.StripeSubscriptionID = subscriptionID
|
||||
u.Billing.StripeSubscriptionStatus = stripe.SubscriptionStatus(status)
|
||||
u.Billing.StripeSubscriptionPaidUntil = time.Unix(paidUntil, 0)
|
||||
u.Billing.StripeSubscriptionCancelAt = time.Unix(cancelAt, 0)
|
||||
if tier == "" {
|
||||
func (s *Server) updateSubscriptionAndTier(u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error {
|
||||
// Remove excess reservations (if too many for tier), and mark associated messages deleted
|
||||
reservations, err := s.userManager.Reservations(u.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
reservationsLimit := visitorDefaultReservationsLimit
|
||||
if tier != nil {
|
||||
reservationsLimit = tier.ReservationsLimit
|
||||
}
|
||||
if int64(len(reservations)) > reservationsLimit {
|
||||
topics := make([]string, 0)
|
||||
for i := int64(len(reservations)) - 1; i >= reservationsLimit; i-- {
|
||||
topics = append(topics, reservations[i].Topic)
|
||||
}
|
||||
if err := s.userManager.RemoveReservations(u.Name, topics...); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.messageCache.ExpireMessages(topics...); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// Change or remove tier
|
||||
if tier == nil {
|
||||
if err := s.userManager.ResetTier(u.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := s.userManager.ChangeTier(u.Name, tier); err != nil {
|
||||
if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := s.userManager.ChangeBilling(u); err != nil {
|
||||
// Update billing fields
|
||||
billing := &user.Billing{
|
||||
StripeCustomerID: customerID,
|
||||
StripeSubscriptionID: subscriptionID,
|
||||
StripeSubscriptionStatus: stripe.SubscriptionStatus(status),
|
||||
StripeSubscriptionPaidUntil: time.Unix(paidUntil, 0),
|
||||
StripeSubscriptionCancelAt: time.Unix(cancelAt, 0),
|
||||
}
|
||||
if err := s.userManager.ChangeBilling(u.Name, billing); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
|
|
@ -1,13 +1,17 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stripe/stripe-go/v74"
|
||||
"heckel.io/ntfy/user"
|
||||
"heckel.io/ntfy/util"
|
||||
"io"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
|
||||
|
@ -70,8 +74,10 @@ func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
|
|||
u, err := s.userManager.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
u.Billing.StripeCustomerID = "acct_123"
|
||||
require.Nil(t, s.userManager.ChangeBilling(u))
|
||||
billing := &user.Billing{
|
||||
StripeCustomerID: "acct_123",
|
||||
}
|
||||
require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
|
||||
|
||||
// Create subscription
|
||||
response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
|
||||
|
@ -109,9 +115,11 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
|
|||
u, err := s.userManager.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
u.Billing.StripeCustomerID = "acct_123"
|
||||
u.Billing.StripeSubscriptionID = "sub_123"
|
||||
require.Nil(t, s.userManager.ChangeBilling(u))
|
||||
billing := &user.Billing{
|
||||
StripeCustomerID: "acct_123",
|
||||
StripeSubscriptionID: "sub_123",
|
||||
}
|
||||
require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
|
||||
|
||||
// Delete account
|
||||
rr := request(t, s, "DELETE", "/v1/account", "", map[string]string{
|
||||
|
@ -125,6 +133,127 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
|
|||
require.Equal(t, 401, rr.Code)
|
||||
}
|
||||
|
||||
func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) {
|
||||
// This tests incoming webhooks from Stripe to update a subscription:
|
||||
// - All Stripe columns are updated in the user table
|
||||
// - When downgrading, excess reservations are deleted, including messages and attachments in
|
||||
// the corresponding topics
|
||||
|
||||
stripeMock := &testStripeAPI{}
|
||||
defer stripeMock.AssertExpectations(t)
|
||||
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.StripeSecretKey = "secret key"
|
||||
c.StripeWebhookKey = "webhook key"
|
||||
s := newTestServer(t, c)
|
||||
s.stripe = stripeMock
|
||||
|
||||
// Define how the mock should react
|
||||
stripeMock.
|
||||
On("ConstructWebhookEvent", mock.Anything, "stripe signature", "webhook key").
|
||||
Return(jsonToStripeEvent(t, subscriptionUpdatedEventJSON), nil)
|
||||
|
||||
// Create a user with a Stripe subscription and 3 reservations
|
||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||
Code: "starter",
|
||||
StripePriceID: "price_1234", // !
|
||||
ReservationsLimit: 1, // !
|
||||
MessagesLimit: 100,
|
||||
MessagesExpiryDuration: time.Hour,
|
||||
AttachmentExpiryDuration: time.Hour,
|
||||
AttachmentFileSizeLimit: 1000000,
|
||||
AttachmentTotalSizeLimit: 1000000,
|
||||
}))
|
||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||
Code: "pro",
|
||||
StripePriceID: "price_1111", // !
|
||||
ReservationsLimit: 3, // !
|
||||
MessagesLimit: 200,
|
||||
MessagesExpiryDuration: time.Hour,
|
||||
AttachmentExpiryDuration: time.Hour,
|
||||
AttachmentFileSizeLimit: 1000000,
|
||||
AttachmentTotalSizeLimit: 1000000,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
require.Nil(t, s.userManager.ReserveAccess("phil", "atopic", user.PermissionDenyAll))
|
||||
require.Nil(t, s.userManager.ReserveAccess("phil", "ztopic", user.PermissionDenyAll))
|
||||
|
||||
// Add billing details
|
||||
u, err := s.userManager.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
billing := &user.Billing{
|
||||
StripeCustomerID: "acct_5555",
|
||||
StripeSubscriptionID: "sub_1234",
|
||||
StripeSubscriptionStatus: stripe.SubscriptionStatusPastDue,
|
||||
StripeSubscriptionPaidUntil: time.Unix(123, 0),
|
||||
StripeSubscriptionCancelAt: time.Unix(456, 0),
|
||||
}
|
||||
require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
|
||||
|
||||
// Add some messages to "atopic" and "ztopic", everything in "ztopic" will be deleted
|
||||
rr := request(t, s, "PUT", "/atopic", "some aaa message", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
rr = request(t, s, "PUT", "/atopic", strings.Repeat("a", 5000), map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
a2 := toMessage(t, rr.Body.String())
|
||||
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, a2.ID))
|
||||
|
||||
rr = request(t, s, "PUT", "/ztopic", "some zzz message", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
rr = request(t, s, "PUT", "/ztopic", strings.Repeat("z", 5000), map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
z2 := toMessage(t, rr.Body.String())
|
||||
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID))
|
||||
|
||||
// Call the webhook: This does all the magic
|
||||
rr = request(t, s, "POST", "/v1/account/billing/webhook", "dummy", map[string]string{
|
||||
"Stripe-Signature": "stripe signature",
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Verify that database columns were updated
|
||||
u, err = s.userManager.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "starter", u.Tier.Code) // Not "pro"
|
||||
require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
|
||||
require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID)
|
||||
require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus) // Not "past_due"
|
||||
require.Equal(t, int64(1674268231), u.Billing.StripeSubscriptionPaidUntil.Unix()) // Updated
|
||||
require.Equal(t, int64(1674299999), u.Billing.StripeSubscriptionCancelAt.Unix()) // Updated
|
||||
|
||||
// Verify that reservations were deleted
|
||||
r, err := s.userManager.Reservations("phil")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(r)) // "ztopic" reservation was deleted
|
||||
require.Equal(t, "atopic", r[0].Topic)
|
||||
|
||||
// Verify that messages and attachments were deleted
|
||||
time.Sleep(time.Second)
|
||||
s.execManager()
|
||||
|
||||
ms, err := s.messageCache.Messages("atopic", sinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, len(ms))
|
||||
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, a2.ID))
|
||||
|
||||
ms, err = s.messageCache.Messages("ztopic", sinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 0, len(ms))
|
||||
require.NoFileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID))
|
||||
}
|
||||
|
||||
type testStripeAPI struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
@ -175,3 +304,34 @@ func (s *testStripeAPI) ConstructWebhookEvent(payload []byte, header string, sec
|
|||
}
|
||||
|
||||
var _ stripeAPI = (*testStripeAPI)(nil)
|
||||
|
||||
func jsonToStripeEvent(t *testing.T, v string) stripe.Event {
|
||||
var e stripe.Event
|
||||
if err := json.Unmarshal([]byte(v), &e); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
const subscriptionUpdatedEventJSON = `
|
||||
{
|
||||
"type": "customer.subscription.updated",
|
||||
"data": {
|
||||
"object": {
|
||||
"id": "sub_1234",
|
||||
"customer": "acct_5555",
|
||||
"status": "active",
|
||||
"current_period_end": 1674268231,
|
||||
"cancel_at": 1674299999,
|
||||
"items": {
|
||||
"data": [
|
||||
{
|
||||
"price": {
|
||||
"id": "price_1234"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
|
|
@ -640,7 +640,7 @@ func TestServer_Auth_Success_User(t *testing.T) {
|
|||
s := newTestServer(t, c)
|
||||
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AllowAccess("", "ben", "mytopic", true, true))
|
||||
require.Nil(t, s.userManager.AllowAccess("ben", "mytopic", user.PermissionReadWrite))
|
||||
|
||||
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
|
@ -654,8 +654,8 @@ func TestServer_Auth_Success_User_MultipleTopics(t *testing.T) {
|
|||
s := newTestServer(t, c)
|
||||
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AllowAccess("", "ben", "mytopic", true, true))
|
||||
require.Nil(t, s.userManager.AllowAccess("", "ben", "anothertopic", true, true))
|
||||
require.Nil(t, s.userManager.AllowAccess("ben", "mytopic", user.PermissionReadWrite))
|
||||
require.Nil(t, s.userManager.AllowAccess("ben", "anothertopic", user.PermissionReadWrite))
|
||||
|
||||
response := request(t, s, "GET", "/mytopic,anothertopic/auth", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
|
@ -688,7 +688,7 @@ func TestServer_Auth_Fail_Unauthorized(t *testing.T) {
|
|||
s := newTestServer(t, c)
|
||||
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AllowAccess("", "ben", "sometopic", true, true)) // Not mytopic!
|
||||
require.Nil(t, s.userManager.AllowAccess("ben", "sometopic", user.PermissionReadWrite)) // Not mytopic!
|
||||
|
||||
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
|
@ -702,8 +702,8 @@ func TestServer_Auth_Fail_CannotPublish(t *testing.T) {
|
|||
s := newTestServer(t, c)
|
||||
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, "unit-test"))
|
||||
require.Nil(t, s.userManager.AllowAccess("", user.Everyone, "private", false, false))
|
||||
require.Nil(t, s.userManager.AllowAccess("", user.Everyone, "announcements", true, false))
|
||||
require.Nil(t, s.userManager.AllowAccess(user.Everyone, "private", user.PermissionDenyAll))
|
||||
require.Nil(t, s.userManager.AllowAccess(user.Everyone, "announcements", user.PermissionRead))
|
||||
|
||||
response := request(t, s, "PUT", "/mytopic", "test", nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
|
@ -750,7 +750,7 @@ func TestServer_StatsResetter(t *testing.T) {
|
|||
go s.runStatsResetter()
|
||||
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
|
||||
require.Nil(t, s.userManager.AllowAccess("", "phil", "mytopic", true, true))
|
||||
require.Nil(t, s.userManager.AllowAccess("phil", "mytopic", user.PermissionReadWrite))
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
response := request(t, s, "PUT", "/mytopic", "test", map[string]string{
|
||||
|
|
|
@ -16,6 +16,10 @@ const (
|
|||
// has to be very high to prevent e-mail abuse, but it doesn't really affect the other limits anyway, since
|
||||
// they are replenished faster (typically).
|
||||
visitorExpungeAfter = 24 * time.Hour
|
||||
|
||||
// visitorDefaultReservationsLimit is the amount of topic names a user without a tier is allowed to reserve.
|
||||
// This number is zero, and changing it may have unintended consequences in the web app, or otherwise
|
||||
visitorDefaultReservationsLimit = int64(0)
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -289,7 +293,7 @@ func defaultVisitorLimits(conf *Config) *visitorLimits {
|
|||
MessagesLimit: replenishDurationToDailyLimit(conf.VisitorRequestLimitReplenish),
|
||||
MessagesExpiryDuration: conf.CacheDuration,
|
||||
EmailsLimit: replenishDurationToDailyLimit(conf.VisitorEmailLimitReplenish),
|
||||
ReservationsLimit: 0, // No reservations for anonymous users, or users without a tier
|
||||
ReservationsLimit: visitorDefaultReservationsLimit,
|
||||
AttachmentTotalSizeLimit: conf.VisitorAttachmentTotalSizeLimit,
|
||||
AttachmentFileSizeLimit: conf.AttachmentFileSizeLimit,
|
||||
AttachmentExpiryDuration: conf.AttachmentExpiryDuration,
|
||||
|
|
|
@ -219,7 +219,6 @@ const (
|
|||
INSERT INTO tier (code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
selectTierIDQuery = `SELECT id FROM tier WHERE code = ?`
|
||||
selectTiersQuery = `
|
||||
SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
|
||||
FROM tier
|
||||
|
@ -234,7 +233,7 @@ const (
|
|||
FROM tier
|
||||
WHERE stripe_price_id = ?
|
||||
`
|
||||
updateUserTierQuery = `UPDATE user SET tier_id = ? WHERE user = ?`
|
||||
updateUserTierQuery = `UPDATE user SET tier_id = (SELECT id FROM tier WHERE code = ?) WHERE user = ?`
|
||||
deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
|
||||
|
||||
updateBillingQuery = `
|
||||
|
@ -772,26 +771,47 @@ func (a *Manager) ChangeRole(username string, role Role) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// ChangeTier changes a user's tier using the tier code
|
||||
// ChangeTier changes a user's tier using the tier code. This function does not delete reservations, messages,
|
||||
// or attachments, even if the new tier has lower limits in this regard. That has to be done elsewhere.
|
||||
func (a *Manager) ChangeTier(username, tier string) error {
|
||||
if !AllowedUsername(username) {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
rows, err := a.db.Query(selectTierIDQuery, tier)
|
||||
t, err := a.Tier(tier)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if err := a.checkReservationsLimit(username, t.ReservationsLimit); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := a.db.Exec(updateUserTierQuery, tier, username); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetTier removes the tier from the given user
|
||||
func (a *Manager) ResetTier(username string) error {
|
||||
if !AllowedUsername(username) && username != Everyone && username != "" {
|
||||
return ErrInvalidArgument
|
||||
} else if err := a.checkReservationsLimit(username, 0); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := a.db.Exec(deleteUserTierQuery, username)
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *Manager) checkReservationsLimit(username string, reservationsLimit int64) error {
|
||||
u, err := a.User(username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
var tierID int64
|
||||
if err := rows.Scan(&tierID); err != nil {
|
||||
if u.Tier != nil && reservationsLimit < u.Tier.ReservationsLimit {
|
||||
reservations, err := a.Reservations(username)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if int64(len(reservations)) > reservationsLimit {
|
||||
return ErrTooManyReservations
|
||||
}
|
||||
rows.Close()
|
||||
if _, err := a.db.Exec(updateUserTierQuery, tierID, username); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -823,20 +843,37 @@ func (a *Manager) CheckAllowAccess(username string, topic string) error {
|
|||
// AllowAccess adds or updates an entry in th access control list for a specific user. It controls
|
||||
// read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry
|
||||
// owner may either be a user (username), or the system (empty).
|
||||
func (a *Manager) AllowAccess(owner, username string, topicPattern string, read bool, write bool) error {
|
||||
func (a *Manager) AllowAccess(username string, topicPattern string, permission Permission) error {
|
||||
if !AllowedUsername(username) && username != Everyone {
|
||||
return ErrInvalidArgument
|
||||
} else if owner != "" && !AllowedUsername(owner) {
|
||||
return ErrInvalidArgument
|
||||
} else if !AllowedTopicPattern(topicPattern) {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
if _, err := a.db.Exec(upsertUserAccessQuery, username, toSQLWildcard(topicPattern), read, write, owner, owner); err != nil {
|
||||
owner := ""
|
||||
if _, err := a.db.Exec(upsertUserAccessQuery, username, toSQLWildcard(topicPattern), permission.IsRead(), permission.IsWrite(), owner, owner); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Manager) ReserveAccess(username string, topic string, everyone Permission) error {
|
||||
if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
tx, err := a.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(upsertUserAccessQuery, username, topic, true, true, username, username); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(upsertUserAccessQuery, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username, username); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// ResetAccess removes an access control list entry for a specific username/topic, or (if topic is
|
||||
// empty) for an entire user. The parameter topicPattern may include wildcards (*).
|
||||
func (a *Manager) ResetAccess(username string, topicPattern string) error {
|
||||
|
@ -856,13 +893,29 @@ func (a *Manager) ResetAccess(username string, topicPattern string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// ResetTier removes the tier from the given user
|
||||
func (a *Manager) ResetTier(username string) error {
|
||||
if !AllowedUsername(username) && username != Everyone && username != "" {
|
||||
func (a *Manager) RemoveReservations(username string, topics ...string) error {
|
||||
if !AllowedUsername(username) || username == Everyone || len(topics) == 0 {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
_, err := a.db.Exec(deleteUserTierQuery, username)
|
||||
for _, topic := range topics {
|
||||
if !AllowedTopic(topic) {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
}
|
||||
tx, err := a.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
for _, topic := range topics {
|
||||
if _, err := tx.Exec(deleteTopicAccessQuery, username, username, topic); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(deleteTopicAccessQuery, Everyone, Everyone, topic); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// DefaultAccess returns the default read/write access if no access control entry matches
|
||||
|
@ -879,8 +932,8 @@ func (a *Manager) CreateTier(tier *Tier) error {
|
|||
}
|
||||
|
||||
// ChangeBilling updates a user's billing fields, namely the Stripe customer ID, and subscription information
|
||||
func (a *Manager) ChangeBilling(user *User) error {
|
||||
if _, err := a.db.Exec(updateBillingQuery, nullString(user.Billing.StripeCustomerID), nullString(user.Billing.StripeSubscriptionID), nullString(string(user.Billing.StripeSubscriptionStatus)), nullInt64(user.Billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(user.Billing.StripeSubscriptionCancelAt.Unix()), user.Name); err != nil {
|
||||
func (a *Manager) ChangeBilling(username string, billing *Billing) error {
|
||||
if _, err := a.db.Exec(updateBillingQuery, nullString(billing.StripeCustomerID), nullString(billing.StripeSubscriptionID), nullString(string(billing.StripeSubscriptionStatus)), nullInt64(billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(billing.StripeSubscriptionCancelAt.Unix()), username); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
|
|
@ -15,13 +15,13 @@ func TestManager_FullScenario_Default_DenyAll(t *testing.T) {
|
|||
a := newTestManager(t, PermissionDenyAll)
|
||||
require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, "unit-test"))
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
|
||||
require.Nil(t, a.AllowAccess("", "ben", "mytopic", true, true))
|
||||
require.Nil(t, a.AllowAccess("", "ben", "readme", true, false))
|
||||
require.Nil(t, a.AllowAccess("", "ben", "writeme", false, true))
|
||||
require.Nil(t, a.AllowAccess("", "ben", "everyonewrite", false, false)) // How unfair!
|
||||
require.Nil(t, a.AllowAccess("", Everyone, "announcements", true, false))
|
||||
require.Nil(t, a.AllowAccess("", Everyone, "everyonewrite", true, true))
|
||||
require.Nil(t, a.AllowAccess("", Everyone, "up*", false, true)) // Everyone can write to /up*
|
||||
require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite))
|
||||
require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead))
|
||||
require.Nil(t, a.AllowAccess("ben", "writeme", PermissionWrite))
|
||||
require.Nil(t, a.AllowAccess("ben", "everyonewrite", PermissionDenyAll)) // How unfair!
|
||||
require.Nil(t, a.AllowAccess(Everyone, "announcements", PermissionRead))
|
||||
require.Nil(t, a.AllowAccess(Everyone, "everyonewrite", PermissionReadWrite))
|
||||
require.Nil(t, a.AllowAccess(Everyone, "up*", PermissionWrite)) // Everyone can write to /up*
|
||||
|
||||
phil, err := a.Authenticate("phil", "phil")
|
||||
require.Nil(t, err)
|
||||
|
@ -130,12 +130,12 @@ func TestManager_UserManagement(t *testing.T) {
|
|||
a := newTestManager(t, PermissionDenyAll)
|
||||
require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, "unit-test"))
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
|
||||
require.Nil(t, a.AllowAccess("", "ben", "mytopic", true, true))
|
||||
require.Nil(t, a.AllowAccess("", "ben", "readme", true, false))
|
||||
require.Nil(t, a.AllowAccess("", "ben", "writeme", false, true))
|
||||
require.Nil(t, a.AllowAccess("", "ben", "everyonewrite", false, false)) // How unfair!
|
||||
require.Nil(t, a.AllowAccess("", Everyone, "announcements", true, false))
|
||||
require.Nil(t, a.AllowAccess("", Everyone, "everyonewrite", true, true))
|
||||
require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite))
|
||||
require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead))
|
||||
require.Nil(t, a.AllowAccess("ben", "writeme", PermissionWrite))
|
||||
require.Nil(t, a.AllowAccess("ben", "everyonewrite", PermissionDenyAll)) // How unfair!
|
||||
require.Nil(t, a.AllowAccess(Everyone, "announcements", PermissionRead))
|
||||
require.Nil(t, a.AllowAccess(Everyone, "everyonewrite", PermissionReadWrite))
|
||||
|
||||
// Query user details
|
||||
phil, err := a.User("phil")
|
||||
|
@ -177,9 +177,9 @@ func TestManager_UserManagement(t *testing.T) {
|
|||
}, everyoneGrants)
|
||||
|
||||
// Ben: Before revoking
|
||||
require.Nil(t, a.AllowAccess("", "ben", "mytopic", true, true)) // Overwrite!
|
||||
require.Nil(t, a.AllowAccess("", "ben", "readme", true, false))
|
||||
require.Nil(t, a.AllowAccess("", "ben", "writeme", false, true))
|
||||
require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite)) // Overwrite!
|
||||
require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead))
|
||||
require.Nil(t, a.AllowAccess("ben", "writeme", PermissionWrite))
|
||||
require.Nil(t, a.Authorize(ben, "mytopic", PermissionRead))
|
||||
require.Nil(t, a.Authorize(ben, "mytopic", PermissionWrite))
|
||||
require.Nil(t, a.Authorize(ben, "readme", PermissionRead))
|
||||
|
@ -234,8 +234,8 @@ func TestManager_ChangePassword(t *testing.T) {
|
|||
func TestManager_ChangeRole(t *testing.T) {
|
||||
a := newTestManager(t, PermissionDenyAll)
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
|
||||
require.Nil(t, a.AllowAccess("", "ben", "mytopic", true, true))
|
||||
require.Nil(t, a.AllowAccess("", "ben", "readme", true, false))
|
||||
require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite))
|
||||
require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead))
|
||||
|
||||
ben, err := a.User("ben")
|
||||
require.Nil(t, err)
|
||||
|
@ -256,6 +256,28 @@ func TestManager_ChangeRole(t *testing.T) {
|
|||
require.Equal(t, 0, len(benGrants))
|
||||
}
|
||||
|
||||
func TestManager_Reservations(t *testing.T) {
|
||||
a := newTestManager(t, PermissionDenyAll)
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
|
||||
require.Nil(t, a.ReserveAccess("ben", "ztopic", PermissionDenyAll))
|
||||
require.Nil(t, a.ReserveAccess("ben", "readme", PermissionRead))
|
||||
require.Nil(t, a.AllowAccess("ben", "something-else", PermissionRead))
|
||||
|
||||
reservations, err := a.Reservations("ben")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, len(reservations))
|
||||
require.Equal(t, Reservation{
|
||||
Topic: "readme",
|
||||
Owner: PermissionReadWrite,
|
||||
Everyone: PermissionRead,
|
||||
}, reservations[0])
|
||||
require.Equal(t, Reservation{
|
||||
Topic: "ztopic",
|
||||
Owner: PermissionReadWrite,
|
||||
Everyone: PermissionDenyAll,
|
||||
}, reservations[1])
|
||||
}
|
||||
|
||||
func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
|
||||
a := newTestManager(t, PermissionDenyAll)
|
||||
require.Nil(t, a.CreateTier(&Tier{
|
||||
|
@ -272,8 +294,7 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
|
|||
}))
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
|
||||
require.Nil(t, a.ChangeTier("ben", "pro"))
|
||||
require.Nil(t, a.AllowAccess("ben", "ben", "mytopic", true, true))
|
||||
require.Nil(t, a.AllowAccess("ben", Everyone, "mytopic", false, false))
|
||||
require.Nil(t, a.ReserveAccess("ben", "mytopic", PermissionDenyAll))
|
||||
|
||||
ben, err := a.User("ben")
|
||||
require.Nil(t, err)
|
||||
|
@ -298,6 +319,13 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
|
|||
require.Equal(t, 1, len(everyoneGrants))
|
||||
require.Equal(t, PermissionDenyAll, everyoneGrants[0].Allow)
|
||||
|
||||
benReservations, err := a.Reservations("ben")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(benReservations))
|
||||
require.Equal(t, "mytopic", benReservations[0].Topic)
|
||||
require.Equal(t, PermissionReadWrite, benReservations[0].Owner)
|
||||
require.Equal(t, PermissionDenyAll, benReservations[0].Everyone)
|
||||
|
||||
// Switch to admin, this should remove all grants and owned ACL entries
|
||||
require.Nil(t, a.ChangeRole("ben", RoleAdmin))
|
||||
|
||||
|
|
|
@ -226,14 +226,5 @@ var (
|
|||
ErrInvalidArgument = errors.New("invalid argument")
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrTierNotFound = errors.New("tier not found")
|
||||
)
|
||||
|
||||
// BillingStatus represents the status of a Stripe subscription
|
||||
type BillingStatus string
|
||||
|
||||
// BillingStatus values, subset of https://stripe.com/docs/billing/subscriptions/overview
|
||||
const (
|
||||
BillingStatusIncomplete = BillingStatus("incomplete")
|
||||
BillingStatusActive = BillingStatus("active")
|
||||
BillingStatusPastDue = BillingStatus("past_due")
|
||||
ErrTooManyReservations = errors.New("new tier has lower reservation limit")
|
||||
)
|
||||
|
|
|
@ -201,15 +201,19 @@
|
|||
"account_delete_dialog_label": "Type '{{username}}' to delete account",
|
||||
"account_delete_dialog_button_cancel": "Cancel",
|
||||
"account_delete_dialog_button_submit": "Permanently delete account",
|
||||
"account_delete_dialog_billing_warning": "Deleting your account also cancels your billing subscription immediately. You will not have access to the billing dashboard anymore.",
|
||||
"account_upgrade_dialog_title": "Change account tier",
|
||||
"account_upgrade_dialog_cancel_warning": "This will <strong>cancel your subscription</strong>, and downgrade your account on {{date}}. On that date, topic reservations as well as messages cached on the server <strong>will be deleted</strong>.",
|
||||
"account_upgrade_dialog_proration_info": "<strong>Proration</strong>: When switching between paid plans, the price difference will be charged or refunded in the next invoice. You will not receive another invoice until the end of the next billing period.",
|
||||
"account_upgrade_dialog_reservations_warning_one": "The selected tier allows fewer reserved topics than your current tier. Before changing your tier, <strong>please delete at least one reservation</strong>. You can remove reservations in the <Link>Settings</Link>.",
|
||||
"account_upgrade_dialog_reservations_warning_other": "The selected tier allows fewer reserved topics than your current tier. Before changing your tier, <strong>please delete at least {{count}} reservations</strong>. You can remove reservations in the <Link>Settings</Link>.",
|
||||
"account_upgrade_dialog_tier_features_reservations": "{{reservations}} reserved topics",
|
||||
"account_upgrade_dialog_tier_features_messages": "{{messages}} daily messages",
|
||||
"account_upgrade_dialog_tier_features_emails": "{{emails}} daily emails",
|
||||
"account_upgrade_dialog_tier_features_attachment_file_size": "{{filesize}} per file",
|
||||
"account_upgrade_dialog_tier_features_attachment_total_size": "{{totalsize}} total storage",
|
||||
"account_upgrade_dialog_tier_selected_label": "Selected",
|
||||
"account_upgrade_dialog_tier_current_label": "Current",
|
||||
"account_upgrade_dialog_button_cancel": "Cancel",
|
||||
"account_upgrade_dialog_button_redirect_signup": "Sign up now",
|
||||
"account_upgrade_dialog_button_pay_now": "Pay now and subscribe",
|
||||
|
|
|
@ -264,7 +264,6 @@ const AccountType = () => {
|
|||
const Stats = () => {
|
||||
const { t } = useTranslation();
|
||||
const { account } = useContext(AccountContext);
|
||||
const [upgradeDialogOpen, setUpgradeDialogOpen] = useState(false);
|
||||
|
||||
if (!account) {
|
||||
return <></>;
|
||||
|
@ -435,6 +434,7 @@ const DeleteAccount = () => {
|
|||
|
||||
const DeleteAccountDialog = (props) => {
|
||||
const { t } = useTranslation();
|
||||
const { account } = useContext(AccountContext);
|
||||
const [username, setUsername] = useState("");
|
||||
const fullScreen = useMediaQuery(theme.breakpoints.down('sm'));
|
||||
const buttonEnabled = username === session.username();
|
||||
|
@ -456,6 +456,9 @@ const DeleteAccountDialog = (props) => {
|
|||
fullWidth
|
||||
variant="standard"
|
||||
/>
|
||||
{account?.billing?.subscription &&
|
||||
<Alert severity="warning" sx={{mt: 1}}>{t("account_delete_dialog_billing_warning")}</Alert>
|
||||
}
|
||||
</DialogContent>
|
||||
<DialogActions>
|
||||
<Button onClick={props.onCancel}>{t("account_delete_dialog_button_cancel")}</Button>
|
||||
|
|
|
@ -3,7 +3,7 @@ import {useContext, useEffect, useState} from 'react';
|
|||
import {
|
||||
Alert,
|
||||
CardActions,
|
||||
CardContent,
|
||||
CardContent, Chip,
|
||||
FormControl,
|
||||
Select,
|
||||
Stack,
|
||||
|
@ -20,6 +20,7 @@ import prefs from "../app/Prefs";
|
|||
import {Paragraph} from "./styles";
|
||||
import EditIcon from '@mui/icons-material/Edit';
|
||||
import CloseIcon from "@mui/icons-material/Close";
|
||||
import WarningIcon from '@mui/icons-material/Warning';
|
||||
import IconButton from "@mui/material/IconButton";
|
||||
import PlayArrowIcon from '@mui/icons-material/PlayArrow';
|
||||
import Container from "@mui/material/Container";
|
||||
|
@ -41,10 +42,12 @@ import routes from "./routes";
|
|||
import accountApi, {UnauthorizedError} from "../app/AccountApi";
|
||||
import {Pref, PrefGroup} from "./Pref";
|
||||
import LockIcon from "@mui/icons-material/Lock";
|
||||
import {Public, PublicOff} from "@mui/icons-material";
|
||||
import {Check, Info, Public, PublicOff} from "@mui/icons-material";
|
||||
import DialogContentText from "@mui/material/DialogContentText";
|
||||
import ReserveTopicSelect from "./ReserveTopicSelect";
|
||||
import {AccountContext} from "./App";
|
||||
import {useOutletContext} from "react-router-dom";
|
||||
import subscriptionManager from "../app/SubscriptionManager";
|
||||
|
||||
const Preferences = () => {
|
||||
return (
|
||||
|
@ -543,6 +546,12 @@ const ReservationsTable = (props) => {
|
|||
const [dialogKey, setDialogKey] = useState(0);
|
||||
const [dialogOpen, setDialogOpen] = useState(false);
|
||||
const [dialogReservation, setDialogReservation] = useState(null);
|
||||
const { subscriptions } = useOutletContext();
|
||||
const localSubscriptions = Object.assign(
|
||||
...subscriptions
|
||||
.filter(s => s.baseUrl === config.base_url)
|
||||
.map(s => ({[s.topic]: s}))
|
||||
);
|
||||
|
||||
const handleEditClick = (reservation) => {
|
||||
setDialogKey(prev => prev+1);
|
||||
|
@ -592,7 +601,9 @@ const ReservationsTable = (props) => {
|
|||
key={reservation.topic}
|
||||
sx={{'&:last-child td, &:last-child th': {border: 0}}}
|
||||
>
|
||||
<TableCell component="th" scope="row" sx={{paddingLeft: 0}} aria-label={t("prefs_reservations_table_topic_header")}>{reservation.topic}</TableCell>
|
||||
<TableCell component="th" scope="row" sx={{paddingLeft: 0}} aria-label={t("prefs_reservations_table_topic_header")}>
|
||||
{reservation.topic}
|
||||
</TableCell>
|
||||
<TableCell aria-label={t("prefs_reservations_table_access_header")}>
|
||||
{reservation.everyone === "read-write" &&
|
||||
<>
|
||||
|
@ -620,6 +631,9 @@ const ReservationsTable = (props) => {
|
|||
}
|
||||
</TableCell>
|
||||
<TableCell align="right">
|
||||
{!localSubscriptions[reservation.topic] &&
|
||||
<Chip icon={<Info/>} label="Not subscribed" color="primary" variant="outlined"/>
|
||||
}
|
||||
<IconButton onClick={() => handleEditClick(reservation)} aria-label={t("prefs_reservations_edit_button")}>
|
||||
<EditIcon/>
|
||||
</IconButton>
|
||||
|
|
|
@ -21,13 +21,14 @@ import {Check} from "@mui/icons-material";
|
|||
import ListItemIcon from "@mui/material/ListItemIcon";
|
||||
import ListItemText from "@mui/material/ListItemText";
|
||||
import Box from "@mui/material/Box";
|
||||
import {NavLink} from "react-router-dom";
|
||||
|
||||
const UpgradeDialog = (props) => {
|
||||
const { t } = useTranslation();
|
||||
const { account } = useContext(AccountContext); // May be undefined!
|
||||
const fullScreen = useMediaQuery(theme.breakpoints.down('sm'));
|
||||
const [tiers, setTiers] = useState(null);
|
||||
const [newTier, setNewTier] = useState(account?.tier?.code); // May be undefined
|
||||
const [newTierCode, setNewTierCode] = useState(account?.tier?.code); // May be undefined
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [errorText, setErrorText] = useState("");
|
||||
|
||||
|
@ -41,47 +42,56 @@ const UpgradeDialog = (props) => {
|
|||
return <></>;
|
||||
}
|
||||
|
||||
const currentTier = account?.tier?.code; // May be undefined
|
||||
let action, submitButtonLabel, submitButtonEnabled;
|
||||
const tiersMap = Object.assign(...tiers.map(tier => ({[tier.code]: tier})));
|
||||
const newTier = tiersMap[newTierCode]; // May be undefined
|
||||
const currentTier = account?.tier; // May be undefined
|
||||
const currentTierCode = currentTier?.code; // May be undefined
|
||||
|
||||
// Figure out buttons, labels and the submit action
|
||||
let submitAction, submitButtonLabel, banner;
|
||||
if (!account) {
|
||||
submitButtonLabel = t("account_upgrade_dialog_button_redirect_signup");
|
||||
submitButtonEnabled = true;
|
||||
action = Action.REDIRECT_SIGNUP;
|
||||
} else if (currentTier === newTier) {
|
||||
submitAction = Action.REDIRECT_SIGNUP;
|
||||
banner = null;
|
||||
} else if (currentTierCode === newTierCode) {
|
||||
submitButtonLabel = t("account_upgrade_dialog_button_update_subscription");
|
||||
submitButtonEnabled = false;
|
||||
action = null;
|
||||
} else if (!currentTier) {
|
||||
submitAction = null;
|
||||
banner = (currentTierCode) ? Banner.PRORATION_INFO : null;
|
||||
} else if (!currentTierCode) {
|
||||
submitButtonLabel = t("account_upgrade_dialog_button_pay_now");
|
||||
submitButtonEnabled = true;
|
||||
action = Action.CREATE_SUBSCRIPTION;
|
||||
} else if (!newTier) {
|
||||
submitAction = Action.CREATE_SUBSCRIPTION;
|
||||
banner = null;
|
||||
} else if (!newTierCode) {
|
||||
submitButtonLabel = t("account_upgrade_dialog_button_cancel_subscription");
|
||||
submitButtonEnabled = true;
|
||||
action = Action.CANCEL_SUBSCRIPTION;
|
||||
submitAction = Action.CANCEL_SUBSCRIPTION;
|
||||
banner = Banner.CANCEL_WARNING;
|
||||
} else {
|
||||
submitButtonLabel = t("account_upgrade_dialog_button_update_subscription");
|
||||
submitButtonEnabled = true;
|
||||
action = Action.UPDATE_SUBSCRIPTION;
|
||||
submitAction = Action.UPDATE_SUBSCRIPTION;
|
||||
banner = Banner.PRORATION_INFO;
|
||||
}
|
||||
|
||||
// Exceptional conditions
|
||||
if (loading) {
|
||||
submitButtonEnabled = false;
|
||||
submitAction = null;
|
||||
} else if (newTier?.code && account?.reservations.length > newTier?.limits.reservations) {
|
||||
submitAction = null;
|
||||
banner = Banner.RESERVATIONS_WARNING;
|
||||
}
|
||||
|
||||
const handleSubmit = async () => {
|
||||
if (action === Action.REDIRECT_SIGNUP) {
|
||||
if (submitAction === Action.REDIRECT_SIGNUP) {
|
||||
window.location.href = routes.signup;
|
||||
return;
|
||||
}
|
||||
try {
|
||||
setLoading(true);
|
||||
if (action === Action.CREATE_SUBSCRIPTION) {
|
||||
const response = await accountApi.createBillingSubscription(newTier);
|
||||
if (submitAction === Action.CREATE_SUBSCRIPTION) {
|
||||
const response = await accountApi.createBillingSubscription(newTierCode);
|
||||
window.location.href = response.redirect_url;
|
||||
} else if (action === Action.UPDATE_SUBSCRIPTION) {
|
||||
await accountApi.updateBillingSubscription(newTier);
|
||||
} else if (action === Action.CANCEL_SUBSCRIPTION) {
|
||||
} else if (submitAction === Action.UPDATE_SUBSCRIPTION) {
|
||||
await accountApi.updateBillingSubscription(newTierCode);
|
||||
} else if (submitAction === Action.CANCEL_SUBSCRIPTION) {
|
||||
await accountApi.deleteBillingSubscription();
|
||||
}
|
||||
props.onCancel();
|
||||
|
@ -116,27 +126,39 @@ const UpgradeDialog = (props) => {
|
|||
<TierCard
|
||||
key={`tierCard${tier.code || '_free'}`}
|
||||
tier={tier}
|
||||
selected={newTier === tier.code} // tier.code may be undefined!
|
||||
onClick={() => setNewTier(tier.code)} // tier.code may be undefined!
|
||||
current={currentTierCode === tier.code} // tier.code or currentTierCode may be undefined!
|
||||
selected={newTierCode === tier.code} // tier.code may be undefined!
|
||||
onClick={() => setNewTierCode(tier.code)} // tier.code may be undefined!
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
{action === Action.CANCEL_SUBSCRIPTION &&
|
||||
{banner === Banner.CANCEL_WARNING &&
|
||||
<Alert severity="warning">
|
||||
<Trans
|
||||
i18nKey="account_upgrade_dialog_cancel_warning"
|
||||
values={{ date: formatShortDate(account?.billing?.paid_until || 0) }} />
|
||||
</Alert>
|
||||
}
|
||||
{currentTier && (!action || action === Action.UPDATE_SUBSCRIPTION) &&
|
||||
{banner === Banner.PRORATION_INFO &&
|
||||
<Alert severity="info">
|
||||
<Trans i18nKey="account_upgrade_dialog_proration_info" />
|
||||
</Alert>
|
||||
}
|
||||
{banner === Banner.RESERVATIONS_WARNING &&
|
||||
<Alert severity="warning">
|
||||
<Trans
|
||||
i18nKey="account_upgrade_dialog_reservations_warning"
|
||||
count={account?.reservations.length - newTier?.limits.reservations}
|
||||
components={{
|
||||
Link: <NavLink to={routes.settings}/>,
|
||||
}}
|
||||
/>
|
||||
</Alert>
|
||||
}
|
||||
</DialogContent>
|
||||
<DialogFooter status={errorText}>
|
||||
<Button onClick={props.onCancel}>{t("account_upgrade_dialog_button_cancel")}</Button>
|
||||
<Button onClick={handleSubmit} disabled={!submitButtonEnabled}>{submitButtonLabel}</Button>
|
||||
<Button onClick={handleSubmit} disabled={!submitAction}>{submitButtonLabel}</Button>
|
||||
</DialogFooter>
|
||||
</Dialog>
|
||||
);
|
||||
|
@ -144,8 +166,19 @@ const UpgradeDialog = (props) => {
|
|||
|
||||
const TierCard = (props) => {
|
||||
const { t } = useTranslation();
|
||||
const cardStyle = (props.selected) ? { background: "#eee", border: "2px solid #338574" } : { border: "2px solid transparent" };
|
||||
const tier = props.tier;
|
||||
let cardStyle, labelStyle, labelText;
|
||||
if (props.selected) {
|
||||
cardStyle = { background: "#eee", border: "2px solid #338574" };
|
||||
labelStyle = { background: "#338574", color: "white" };
|
||||
labelText = t("account_upgrade_dialog_tier_selected_label");
|
||||
} else if (props.current) {
|
||||
cardStyle = { border: "2px solid #eee" };
|
||||
labelStyle = { background: "#eee", color: "black" };
|
||||
labelText = t("account_upgrade_dialog_tier_current_label");
|
||||
} else {
|
||||
cardStyle = { border: "2px solid transparent" };
|
||||
}
|
||||
|
||||
return (
|
||||
<Box sx={{
|
||||
|
@ -163,16 +196,15 @@ const TierCard = (props) => {
|
|||
<Card sx={{ height: "100%" }}>
|
||||
<CardActionArea sx={{ height: "100%" }}>
|
||||
<CardContent onClick={props.onClick} sx={{ height: "100%" }}>
|
||||
{props.selected &&
|
||||
{labelStyle &&
|
||||
<div style={{
|
||||
position: "absolute",
|
||||
top: "0",
|
||||
right: "15px",
|
||||
padding: "2px 10px",
|
||||
background: "#338574",
|
||||
color: "white",
|
||||
borderRadius: "3px",
|
||||
}}>{t("account_upgrade_dialog_tier_selected_label")}</div>
|
||||
...labelStyle
|
||||
}}>{labelText}</div>
|
||||
}
|
||||
<Typography variant="h5" component="div">
|
||||
{tier.name || t("account_usage_tier_free")}
|
||||
|
@ -217,10 +249,17 @@ const FeatureItem = (props) => {
|
|||
};
|
||||
|
||||
const Action = {
|
||||
REDIRECT_SIGNUP: 0,
|
||||
CREATE_SUBSCRIPTION: 1,
|
||||
UPDATE_SUBSCRIPTION: 2,
|
||||
CANCEL_SUBSCRIPTION: 3
|
||||
REDIRECT_SIGNUP: 1,
|
||||
CREATE_SUBSCRIPTION: 2,
|
||||
UPDATE_SUBSCRIPTION: 3,
|
||||
CANCEL_SUBSCRIPTION: 4
|
||||
};
|
||||
|
||||
const Banner = {
|
||||
CANCEL_WARNING: 1,
|
||||
PRORATION_INFO: 2,
|
||||
RESERVATIONS_WARNING: 3
|
||||
};
|
||||
|
||||
|
||||
export default UpgradeDialog;
|
||||
|
|
Loading…
Reference in a new issue