1
0
Fork 0
mirror of https://github.com/binwiederhier/ntfy.git synced 2024-11-21 19:03:26 +01:00

Support for annual billing intervals

This commit is contained in:
binwiederhier 2023-02-21 22:44:30 -05:00
parent 07afaf961d
commit ef9d6d9f6c
17 changed files with 453 additions and 183 deletions

View file

@ -54,7 +54,8 @@ var cmdTier = &cli.Command{
&cli.StringFlag{Name: "attachment-total-size-limit", Value: defaultAttachmentTotalSizeLimit, Usage: "total size limit of attachments for the user"},
&cli.DurationFlag{Name: "attachment-expiry-duration", Value: defaultAttachmentExpiryDuration, Usage: "duration after which attachments are deleted"},
&cli.StringFlag{Name: "attachment-bandwidth-limit", Value: defaultAttachmentBandwidthLimit, Usage: "daily bandwidth limit for attachment uploads/downloads"},
&cli.StringFlag{Name: "stripe-price-id", Usage: "Stripe price ID for paid tiers (e.g. price_12345)"},
&cli.StringFlag{Name: "stripe-monthly-price-id", Usage: "Monthly Stripe price ID for paid tiers (e.g. price_12345)"},
&cli.StringFlag{Name: "stripe-yearly-price-id", Usage: "Yearly Stripe price ID for paid tiers (e.g. price_12345)"},
&cli.BoolFlag{Name: "ignore-exists", Usage: "if the tier already exists, perform no action and exit"},
},
Description: `Add a new tier to the ntfy user database.
@ -96,7 +97,8 @@ Examples:
&cli.StringFlag{Name: "attachment-total-size-limit", Usage: "total size limit of attachments for the user"},
&cli.DurationFlag{Name: "attachment-expiry-duration", Usage: "duration after which attachments are deleted"},
&cli.StringFlag{Name: "attachment-bandwidth-limit", Usage: "daily bandwidth limit for attachment uploads/downloads"},
&cli.StringFlag{Name: "stripe-price-id", Usage: "Stripe price ID for paid tiers (e.g. price_12345)"},
&cli.StringFlag{Name: "stripe-monthly-price-id", Usage: "Monthly Stripe price ID for paid tiers (e.g. price_12345)"},
&cli.StringFlag{Name: "stripe-yearly-price-id", Usage: "Yearly Stripe price ID for paid tiers (e.g. price_12345)"},
},
Description: `Updates a tier to change the limits.
@ -110,7 +112,8 @@ Examples:
ntfy tier change --name="Pro" pro # Update the name of an existing tier
ntfy tier change \ # Update multiple limits and fields
--message-expiry-duration=24h \
--stripe-price-id=price_1234 \
--stripe-monthly-price-id=price_1234 \
--stripe-monthly-price-id=price_5678 \
pro
`,
},
@ -166,6 +169,10 @@ func execTierAdd(c *cli.Context) error {
return errors.New("tier code expected, type 'ntfy tier add --help' for help")
} else if !user.AllowedTier(code) {
return errors.New("tier code must consist only of numbers and letters")
} else if c.String("stripe-monthly-price-id") != "" && c.String("stripe-yearly-price-id") == "" {
return errors.New("if stripe-monthly-price-id is set, stripe-yearly-price-id must also be set")
} else if c.String("stripe-monthly-price-id") == "" && c.String("stripe-yearly-price-id") != "" {
return errors.New("if stripe-yearly-price-id is set, stripe-monthly-price-id must also be set")
}
manager, err := createUserManager(c)
if err != nil {
@ -206,7 +213,8 @@ func execTierAdd(c *cli.Context) error {
AttachmentTotalSizeLimit: attachmentTotalSizeLimit,
AttachmentExpiryDuration: c.Duration("attachment-expiry-duration"),
AttachmentBandwidthLimit: attachmentBandwidthLimit,
StripePriceID: c.String("stripe-price-id"),
StripeMonthlyPriceID: c.String("stripe-monthly-price-id"),
StripeYearlyPriceID: c.String("stripe-yearly-price-id"),
}
if err := manager.AddTier(tier); err != nil {
return err
@ -273,8 +281,16 @@ func execTierChange(c *cli.Context) error {
return err
}
}
if c.IsSet("stripe-price-id") {
tier.StripePriceID = c.String("stripe-price-id")
if c.IsSet("stripe-monthly-price-id") {
tier.StripeMonthlyPriceID = c.String("stripe-monthly-price-id")
}
if c.IsSet("stripe-yearly-price-id") {
tier.StripeYearlyPriceID = c.String("stripe-yearly-price-id")
}
if tier.StripeMonthlyPriceID != "" && tier.StripeYearlyPriceID == "" {
return errors.New("if stripe-monthly-price-id is set, stripe-yearly-price-id must also be set")
} else if tier.StripeMonthlyPriceID == "" && tier.StripeYearlyPriceID != "" {
return errors.New("if stripe-yearly-price-id is set, stripe-monthly-price-id must also be set")
}
if err := manager.UpdateTier(tier); err != nil {
return err
@ -319,9 +335,9 @@ func execTierList(c *cli.Context) error {
}
func printTier(c *cli.Context, tier *user.Tier) {
stripePriceID := tier.StripePriceID
if stripePriceID == "" {
stripePriceID = "(none)"
prices := "(none)"
if tier.StripeMonthlyPriceID != "" && tier.StripeYearlyPriceID != "" {
prices = fmt.Sprintf("%s / %s", tier.StripeMonthlyPriceID, tier.StripeYearlyPriceID)
}
fmt.Fprintf(c.App.ErrWriter, "tier %s (id: %s)\n", tier.Code, tier.ID)
fmt.Fprintf(c.App.ErrWriter, "- Name: %s\n", tier.Name)
@ -333,5 +349,5 @@ func printTier(c *cli.Context, tier *user.Tier) {
fmt.Fprintf(c.App.ErrWriter, "- Attachment total size limit: %s\n", util.FormatSize(tier.AttachmentTotalSizeLimit))
fmt.Fprintf(c.App.ErrWriter, "- Attachment expiry duration: %s (%d seconds)\n", tier.AttachmentExpiryDuration.String(), int64(tier.AttachmentExpiryDuration.Seconds()))
fmt.Fprintf(c.App.ErrWriter, "- Attachment daily bandwidth limit: %s\n", util.FormatSize(tier.AttachmentBandwidthLimit))
fmt.Fprintf(c.App.ErrWriter, "- Stripe price: %s\n", stripePriceID)
fmt.Fprintf(c.App.ErrWriter, "- Stripe prices (monthly/yearly): %s\n", prices)
}

View file

@ -36,7 +36,8 @@ func TestCLI_Tier_AddListChangeDelete(t *testing.T) {
"--attachment-expiry-duration=7h",
"--attachment-total-size-limit=10G",
"--attachment-bandwidth-limit=100G",
"--stripe-price-id=price_991",
"--stripe-monthly-price-id=price_991",
"--stripe-yearly-price-id=price_992",
"pro",
))
require.Contains(t, stderr.String(), "- Message limit: 999")
@ -46,7 +47,7 @@ func TestCLI_Tier_AddListChangeDelete(t *testing.T) {
require.Contains(t, stderr.String(), "- Attachment file size limit: 100.0 MB")
require.Contains(t, stderr.String(), "- Attachment expiry duration: 7h")
require.Contains(t, stderr.String(), "- Attachment total size limit: 10.0 GB")
require.Contains(t, stderr.String(), "- Stripe price: price_991")
require.Contains(t, stderr.String(), "- Stripe prices (monthly/yearly): price_991 / price_992")
app, _, _, stderr = newTestApp()
require.Nil(t, runTierCommand(app, conf, "remove", "pro"))

View file

@ -8,6 +8,7 @@ and the [ntfy Android app](https://github.com/binwiederhier/ntfy-android/release
* Support for publishing to protected topics via email with access tokens ([#612](https://github.com/binwiederhier/ntfy/pull/621), thanks to [@tamcore](https://github.com/tamcore))
* Support for base64-encoded and nested multipart emails ([#610](https://github.com/binwiederhier/ntfy/issues/610), thanks to [@Robert-litts](https://github.com/Robert-litts))
* Add support for annual billing intervals (no ticket)
**Bug fixes + maintenance:**

View file

@ -45,11 +45,11 @@ type Server struct {
visitors map[string]*visitor // ip:<ip> or user:<user>
firebaseClient *firebaseClient
messages int64
userManager *user.Manager // Might be nil!
messageCache *messageCache // Database that stores the messages
fileCache *fileCache // File system based cache that stores attachments
stripe stripeAPI // Stripe API, can be replaced with a mock
priceCache *util.LookupCache[map[string]string] // Stripe price ID -> formatted price
userManager *user.Manager // Might be nil!
messageCache *messageCache // Database that stores the messages
fileCache *fileCache // File system based cache that stores attachments
stripe stripeAPI // Stripe API, can be replaced with a mock
priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!)
closeChan chan bool
mu sync.Mutex
}

View file

@ -100,6 +100,7 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, r *http.Request, v *vis
Customer: true,
Subscription: u.Billing.StripeSubscriptionID != "",
Status: string(u.Billing.StripeSubscriptionStatus),
Interval: string(u.Billing.StripeSubscriptionInterval),
PaidUntil: u.Billing.StripeSubscriptionPaidUntil.Unix(),
CancelAt: u.Billing.StripeSubscriptionCancelAt.Unix(),
}

View file

@ -80,14 +80,17 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
return err
}
for _, tier := range tiers {
priceStr, ok := prices[tier.StripePriceID]
if tier.StripePriceID == "" || !ok {
priceMonth, priceYear := prices[tier.StripeMonthlyPriceID], prices[tier.StripeYearlyPriceID]
if priceMonth == 0 || priceYear == 0 { // Only allow tiers that have both prices!
continue
}
response = append(response, &apiAccountBillingTier{
Code: tier.Code,
Name: tier.Name,
Price: priceStr,
Code: tier.Code,
Name: tier.Name,
Prices: &apiAccountBillingPrices{
Month: priceMonth,
Year: priceYear,
},
Limits: &apiAccountLimits{
Basis: string(visitorLimitBasisTier),
Messages: tier.MessageLimit,
@ -117,11 +120,21 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
tier, err := s.userManager.Tier(req.Tier)
if err != nil {
return err
} else if tier.StripePriceID == "" {
}
var priceID string
if req.Interval == string(stripe.PriceRecurringIntervalMonth) && tier.StripeMonthlyPriceID != "" {
priceID = tier.StripeMonthlyPriceID
} else if req.Interval == string(stripe.PriceRecurringIntervalYear) && tier.StripeYearlyPriceID != "" {
priceID = tier.StripeYearlyPriceID
} else {
return errNotAPaidTier
}
logvr(v, r).
With(tier).
Fields(log.Context{
"stripe_price_id": priceID,
"stripe_subscription_interval": req.Interval,
}).
Tag(tagStripe).
Info("Creating Stripe checkout flow")
var stripeCustomerID *string
@ -143,7 +156,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
AllowPromotionCodes: stripe.Bool(true),
LineItems: []*stripe.CheckoutSessionLineItemParams{
{
Price: stripe.String(tier.StripePriceID),
Price: stripe.String(priceID),
Quantity: stripe.Int64(1),
},
},
@ -180,10 +193,11 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
sub, err := s.stripe.GetSubscription(sess.Subscription.ID)
if err != nil {
return err
} else if sub.Items == nil || len(sub.Items.Data) != 1 || sub.Items.Data[0].Price == nil {
} else if sub.Items == nil || len(sub.Items.Data) != 1 || sub.Items.Data[0].Price == nil || sub.Items.Data[0].Price.Recurring == nil {
return wrapErrHTTP(errHTTPBadRequestBillingRequestInvalid, "more than one line item in existing subscription")
}
tier, err := s.userManager.TierByStripePrice(sub.Items.Data[0].Price.ID)
priceID, interval := sub.Items.Data[0].Price.ID, sub.Items.Data[0].Price.Recurring.Interval
tier, err := s.userManager.TierByStripePrice(priceID)
if err != nil {
return err
}
@ -197,8 +211,10 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
Tag(tagStripe).
Fields(log.Context{
"stripe_customer_id": sess.Customer.ID,
"stripe_price_id": priceID,
"stripe_subscription_id": sub.ID,
"stripe_subscription_status": string(sub.Status),
"stripe_subscription_interval": string(interval),
"stripe_subscription_paid_until": sub.CurrentPeriodEnd,
}).
Info("Stripe checkout flow succeeded, updating user tier and subscription")
@ -213,7 +229,7 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
if _, err := s.stripe.UpdateCustomer(sess.Customer.ID, customerParams); err != nil {
return err
}
if err := s.updateSubscriptionAndTier(r, v, u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil {
if err := s.updateSubscriptionAndTier(r, v, u, tier, sess.Customer.ID, sub.ID, string(sub.Status), string(interval), sub.CurrentPeriodEnd, sub.CancelAt); err != nil {
return err
}
http.Redirect(w, r, s.config.BaseURL+accountPath, http.StatusSeeOther)
@ -235,15 +251,24 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
if err != nil {
return err
}
var priceID string
if req.Interval == string(stripe.PriceRecurringIntervalMonth) && tier.StripeMonthlyPriceID != "" {
priceID = tier.StripeMonthlyPriceID
} else if req.Interval == string(stripe.PriceRecurringIntervalYear) && tier.StripeYearlyPriceID != "" {
priceID = tier.StripeYearlyPriceID
} else {
return errNotAPaidTier
}
logvr(v, r).
Tag(tagStripe).
Fields(log.Context{
"new_tier_id": tier.ID,
"new_tier_name": tier.Name,
"new_tier_stripe_price_id": tier.StripePriceID,
"new_tier_id": tier.ID,
"new_tier_code": tier.Code,
"new_tier_stripe_price_id": priceID,
"new_tier_stripe_subscription_interval": req.Interval,
// Other stripe_* fields filled by visitor context
}).
Info("Changing Stripe subscription and billing tier to %s/%s (price %s)", tier.ID, tier.Name, tier.StripePriceID)
Info("Changing Stripe subscription and billing tier to %s/%s (price %s, %s)", tier.ID, tier.Name, priceID, req.Interval)
sub, err := s.stripe.GetSubscription(u.Billing.StripeSubscriptionID)
if err != nil {
return err
@ -252,11 +277,11 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
}
params := &stripe.SubscriptionParams{
CancelAtPeriodEnd: stripe.Bool(false),
ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)),
ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorAlwaysInvoice)),
Items: []*stripe.SubscriptionItemsParams{
{
ID: stripe.String(sub.Items.Data[0].ID),
Price: stripe.String(tier.StripePriceID),
Price: stripe.String(priceID),
},
},
}
@ -345,20 +370,22 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(r *http.Request,
ev, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event.Data.Raw)))
if err != nil {
return err
} else if ev.ID == "" || ev.Customer == "" || ev.Status == "" || ev.CurrentPeriodEnd == 0 || ev.Items == nil || len(ev.Items.Data) != 1 || ev.Items.Data[0].Price == nil || ev.Items.Data[0].Price.ID == "" {
} else if ev.ID == "" || ev.Customer == "" || ev.Status == "" || ev.CurrentPeriodEnd == 0 || ev.Items == nil || len(ev.Items.Data) != 1 || ev.Items.Data[0].Price == nil || ev.Items.Data[0].Price.ID == "" || ev.Items.Data[0].Price.Recurring == nil {
logvr(v, r).Tag(tagStripe).Field("stripe_request", fmt.Sprintf("%#v", ev)).Warn("Unexpected request from Stripe")
return errHTTPBadRequestBillingRequestInvalid
}
subscriptionID, priceID := ev.ID, ev.Items.Data[0].Price.ID
subscriptionID, priceID, interval := ev.ID, ev.Items.Data[0].Price.ID, ev.Items.Data[0].Price.Recurring.Interval
logvr(v, r).
Tag(tagStripe).
Fields(log.Context{
"stripe_webhook_type": event.Type,
"stripe_customer_id": ev.Customer,
"stripe_price_id": priceID,
"stripe_subscription_id": ev.ID,
"stripe_subscription_status": ev.Status,
"stripe_subscription_interval": interval,
"stripe_subscription_paid_until": ev.CurrentPeriodEnd,
"stripe_subscription_cancel_at": ev.CancelAt,
"stripe_price_id": priceID,
}).
Info("Updating subscription to status %s, with price %s", ev.Status, priceID)
userFn := func() (*user.User, error) {
@ -376,7 +403,7 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(r *http.Request,
if err != nil {
return err
}
if err := s.updateSubscriptionAndTier(r, v, u, tier, ev.Customer, subscriptionID, ev.Status, ev.CurrentPeriodEnd, ev.CancelAt); err != nil {
if err := s.updateSubscriptionAndTier(r, v, u, tier, ev.Customer, subscriptionID, ev.Status, string(interval), ev.CurrentPeriodEnd, ev.CancelAt); err != nil {
return err
}
s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u))
@ -399,14 +426,14 @@ func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(r *http.Request,
Tag(tagStripe).
Field("stripe_webhook_type", event.Type).
Info("Subscription deleted, downgrading to unpaid tier")
if err := s.updateSubscriptionAndTier(r, v, u, nil, ev.Customer, "", "", 0, 0); err != nil {
if err := s.updateSubscriptionAndTier(r, v, u, nil, ev.Customer, "", "", "", 0, 0); err != nil {
return err
}
s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u))
return nil
}
func (s *Server) updateSubscriptionAndTier(r *http.Request, v *visitor, u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error {
func (s *Server) updateSubscriptionAndTier(r *http.Request, v *visitor, u *user.User, tier *user.Tier, customerID, subscriptionID, status, interval string, paidUntil, cancelAt int64) error {
reservationsLimit := visitorDefaultReservationsLimit
if tier != nil {
reservationsLimit = tier.ReservationLimit
@ -423,9 +450,8 @@ func (s *Server) updateSubscriptionAndTier(r *http.Request, v *visitor, u *user.
logvr(v, r).
Tag(tagStripe).
Fields(log.Context{
"new_tier_id": tier.ID,
"new_tier_name": tier.Name,
"new_tier_stripe_price_id": tier.StripePriceID,
"new_tier_id": tier.ID,
"new_tier_code": tier.Code,
}).
Info("Changing tier to tier %s (%s) for user %s", tier.ID, tier.Name, u.Name)
if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil {
@ -437,6 +463,7 @@ func (s *Server) updateSubscriptionAndTier(r *http.Request, v *visitor, u *user.
StripeCustomerID: customerID,
StripeSubscriptionID: subscriptionID,
StripeSubscriptionStatus: stripe.SubscriptionStatus(status),
StripeSubscriptionInterval: stripe.PriceRecurringInterval(interval),
StripeSubscriptionPaidUntil: time.Unix(paidUntil, 0),
StripeSubscriptionCancelAt: time.Unix(cancelAt, 0),
}
@ -448,20 +475,16 @@ func (s *Server) updateSubscriptionAndTier(r *http.Request, v *visitor, u *user.
// fetchStripePrices contacts the Stripe API to retrieve all prices. This is used by the server to cache the prices
// in memory, and ultimately for the web app to display the price table.
func (s *Server) fetchStripePrices() (map[string]string, error) {
func (s *Server) fetchStripePrices() (map[string]int64, error) {
log.Debug("Caching prices from Stripe API")
priceMap := make(map[string]string)
priceMap := make(map[string]int64)
prices, err := s.stripe.ListPrices(&stripe.PriceListParams{Active: stripe.Bool(true)})
if err != nil {
log.Warn("Fetching Stripe prices failed: %s", err.Error())
return nil, err
}
for _, p := range prices {
if p.UnitAmount%100 == 0 {
priceMap[p.ID] = fmt.Sprintf("$%d", p.UnitAmount/100)
} else {
priceMap[p.ID] = fmt.Sprintf("$%.2f", float64(p.UnitAmount)/100)
}
priceMap[p.ID] = p.UnitAmount
log.Trace("- Caching price %s = %v", p.ID, priceMap[p.ID])
}
return priceMap, nil

View file

@ -37,7 +37,9 @@ func TestPayments_Tiers(t *testing.T) {
On("ListPrices", mock.Anything).
Return([]*stripe.Price{
{ID: "price_123", UnitAmount: 500},
{ID: "price_124", UnitAmount: 5000},
{ID: "price_456", UnitAmount: 1000},
{ID: "price_457", UnitAmount: 10000},
{ID: "price_999", UnitAmount: 9999},
}, nil)
@ -58,7 +60,8 @@ func TestPayments_Tiers(t *testing.T) {
AttachmentFileSizeLimit: 999,
AttachmentTotalSizeLimit: 888,
AttachmentExpiryDuration: time.Minute,
StripePriceID: "price_123",
StripeMonthlyPriceID: "price_123",
StripeYearlyPriceID: "price_124",
}))
require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_444",
@ -71,7 +74,8 @@ func TestPayments_Tiers(t *testing.T) {
AttachmentFileSizeLimit: 999111,
AttachmentTotalSizeLimit: 888111,
AttachmentExpiryDuration: time.Hour,
StripePriceID: "price_456",
StripeMonthlyPriceID: "price_456",
StripeYearlyPriceID: "price_457",
}))
response := request(t, s, "GET", "/v1/tiers", "", nil)
require.Equal(t, 200, response.Code)
@ -98,6 +102,8 @@ func TestPayments_Tiers(t *testing.T) {
require.Equal(t, "pro", tier.Code)
require.Equal(t, "Pro", tier.Name)
require.Equal(t, "tier", tier.Limits.Basis)
require.Equal(t, int64(500), tier.Prices.Month)
require.Equal(t, int64(5000), tier.Prices.Year)
require.Equal(t, int64(777), tier.Limits.Reservations)
require.Equal(t, int64(1000), tier.Limits.Messages)
require.Equal(t, int64(3600), tier.Limits.MessagesExpiryDuration)
@ -109,6 +115,8 @@ func TestPayments_Tiers(t *testing.T) {
tier = tiers[2]
require.Equal(t, "business", tier.Code)
require.Equal(t, "Business", tier.Name)
require.Equal(t, int64(1000), tier.Prices.Month)
require.Equal(t, int64(10000), tier.Prices.Year)
require.Equal(t, "tier", tier.Limits.Basis)
require.Equal(t, int64(777333), tier.Limits.Reservations)
require.Equal(t, int64(2000), tier.Limits.Messages)
@ -136,14 +144,14 @@ func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
// Create tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_123",
Code: "pro",
StripePriceID: "price_123",
ID: "ti_123",
Code: "pro",
StripeMonthlyPriceID: "price_123",
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
// Create subscription
response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro", "interval": "month"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, response.Code)
@ -172,9 +180,9 @@ func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
// Create tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_123",
Code: "pro",
StripePriceID: "price_123",
ID: "ti_123",
Code: "pro",
StripeMonthlyPriceID: "price_123",
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
@ -187,7 +195,7 @@ func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
// Create subscription
response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro", "interval": "month"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, response.Code)
@ -214,9 +222,9 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
// Create tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_123",
Code: "pro",
StripePriceID: "price_123",
ID: "ti_123",
Code: "pro",
StripeMonthlyPriceID: "price_123",
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
@ -267,7 +275,7 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes
require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_123",
Code: "starter",
StripePriceID: "price_1234",
StripeMonthlyPriceID: "price_1234",
ReservationLimit: 1,
MessageLimit: 220, // 220 * 5% = 11 requests before rate limiting kicks in
MessageExpiryDuration: time.Hour,
@ -298,7 +306,12 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes
Items: &stripe.SubscriptionItemList{
Data: []*stripe.SubscriptionItem{
{
Price: &stripe.Price{ID: "price_1234"},
Price: &stripe.Price{
ID: "price_1234",
Recurring: &stripe.PriceRecurring{
Interval: stripe.PriceRecurringIntervalMonth,
},
},
},
},
},
@ -333,6 +346,7 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes
require.Equal(t, "", u.Billing.StripeCustomerID)
require.Equal(t, "", u.Billing.StripeSubscriptionID)
require.Equal(t, stripe.SubscriptionStatus(""), u.Billing.StripeSubscriptionStatus)
require.Equal(t, stripe.PriceRecurringInterval(""), u.Billing.StripeSubscriptionInterval)
require.Equal(t, int64(0), u.Billing.StripeSubscriptionPaidUntil.Unix())
require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
require.Equal(t, int64(0), u.Stats.Messages) // Messages and emails are not persisted for no-tier users!
@ -349,6 +363,7 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes
require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID)
require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus)
require.Equal(t, stripe.PriceRecurringIntervalMonth, u.Billing.StripeSubscriptionInterval)
require.Equal(t, int64(123456789), u.Billing.StripeSubscriptionPaidUntil.Unix())
require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
require.Equal(t, int64(0), u.Stats.Messages)
@ -423,7 +438,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_1",
Code: "starter",
StripePriceID: "price_1234", // !
StripeMonthlyPriceID: "price_1234", // !
ReservationLimit: 1, // !
MessageLimit: 100,
MessageExpiryDuration: time.Hour,
@ -435,7 +450,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_2",
Code: "pro",
StripePriceID: "price_1111", // !
StripeMonthlyPriceID: "price_1111", // !
ReservationLimit: 3, // !
MessageLimit: 200,
MessageExpiryDuration: time.Hour,
@ -457,6 +472,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
StripeCustomerID: "acct_5555",
StripeSubscriptionID: "sub_1234",
StripeSubscriptionStatus: stripe.SubscriptionStatusPastDue,
StripeSubscriptionInterval: stripe.PriceRecurringIntervalMonth,
StripeSubscriptionPaidUntil: time.Unix(123, 0),
StripeSubscriptionCancelAt: time.Unix(456, 0),
}
@ -499,9 +515,10 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
require.Equal(t, "starter", u.Tier.Code) // Not "pro"
require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID)
require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus) // Not "past_due"
require.Equal(t, int64(1674268231), u.Billing.StripeSubscriptionPaidUntil.Unix()) // Updated
require.Equal(t, int64(1674299999), u.Billing.StripeSubscriptionCancelAt.Unix()) // Updated
require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus) // Not "past_due"
require.Equal(t, stripe.PriceRecurringIntervalYear, u.Billing.StripeSubscriptionInterval) // Not "month"
require.Equal(t, int64(1674268231), u.Billing.StripeSubscriptionPaidUntil.Unix()) // Updated
require.Equal(t, int64(1674299999), u.Billing.StripeSubscriptionCancelAt.Unix()) // Updated
// Verify that reservations were deleted
r, err := s.userManager.Reservations("phil")
@ -546,10 +563,10 @@ func TestPayments_Webhook_Subscription_Deleted(t *testing.T) {
// Create a user with a Stripe subscription and 3 reservations
require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_1",
Code: "pro",
StripePriceID: "price_1234",
ReservationLimit: 1,
ID: "ti_1",
Code: "pro",
StripeMonthlyPriceID: "price_1234",
ReservationLimit: 1,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
@ -562,6 +579,7 @@ func TestPayments_Webhook_Subscription_Deleted(t *testing.T) {
StripeCustomerID: "acct_5555",
StripeSubscriptionID: "sub_1234",
StripeSubscriptionStatus: stripe.SubscriptionStatusPastDue,
StripeSubscriptionInterval: stripe.PriceRecurringIntervalMonth,
StripeSubscriptionPaidUntil: time.Unix(123, 0),
StripeSubscriptionCancelAt: time.Unix(0, 0),
}))
@ -615,11 +633,11 @@ func TestPayments_Subscription_Update_Different_Tier(t *testing.T) {
stripeMock.
On("UpdateSubscription", "sub_123", &stripe.SubscriptionParams{
CancelAtPeriodEnd: stripe.Bool(false),
ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)),
ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorAlwaysInvoice)),
Items: []*stripe.SubscriptionItemsParams{
{
ID: stripe.String("someid_123"),
Price: stripe.String("price_456"),
Price: stripe.String("price_457"),
},
},
}).
@ -627,14 +645,16 @@ func TestPayments_Subscription_Update_Different_Tier(t *testing.T) {
// Create tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_123",
Code: "pro",
StripePriceID: "price_123",
ID: "ti_123",
Code: "pro",
StripeMonthlyPriceID: "price_123",
StripeYearlyPriceID: "price_124",
}))
require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_456",
Code: "business",
StripePriceID: "price_456",
ID: "ti_456",
Code: "business",
StripeMonthlyPriceID: "price_456",
StripeYearlyPriceID: "price_457",
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
@ -644,7 +664,7 @@ func TestPayments_Subscription_Update_Different_Tier(t *testing.T) {
}))
// Call endpoint to change subscription
rr := request(t, s, "PUT", "/v1/account/billing/subscription", `{"tier":"business"}`, map[string]string{
rr := request(t, s, "PUT", "/v1/account/billing/subscription", `{"tier":"business","interval":"year"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
@ -795,7 +815,10 @@ const subscriptionUpdatedEventJSON = `
"data": [
{
"price": {
"id": "price_1234"
"id": "price_1234",
"recurring": {
"interval": "year"
}
}
}
]
@ -818,7 +841,10 @@ const subscriptionDeletedEventJSON = `
"data": [
{
"price": {
"id": "price_1234"
"id": "price_1234",
"recurring": {
"interval": "month"
}
}
}
]

View file

@ -309,6 +309,7 @@ type apiAccountBilling struct {
Customer bool `json:"customer"`
Subscription bool `json:"subscription"`
Status string `json:"status,omitempty"`
Interval string `json:"interval,omitempty"`
PaidUntil int64 `json:"paid_until,omitempty"`
CancelAt int64 `json:"cancel_at,omitempty"`
}
@ -343,11 +344,16 @@ type apiConfigResponse struct {
DisallowedTopics []string `json:"disallowed_topics"`
}
type apiAccountBillingPrices struct {
Month int64 `json:"month"`
Year int64 `json:"year"`
}
type apiAccountBillingTier struct {
Code string `json:"code,omitempty"`
Name string `json:"name,omitempty"`
Price string `json:"price,omitempty"`
Limits *apiAccountLimits `json:"limits"`
Code string `json:"code,omitempty"`
Name string `json:"name,omitempty"`
Prices *apiAccountBillingPrices `json:"prices,omitempty"`
Limits *apiAccountLimits `json:"limits"`
}
type apiAccountBillingSubscriptionCreateResponse struct {
@ -355,7 +361,8 @@ type apiAccountBillingSubscriptionCreateResponse struct {
}
type apiAccountBillingSubscriptionChangeRequest struct {
Tier string `json:"tier"`
Tier string `json:"tier"`
Interval string `json:"interval"`
}
type apiAccountBillingPortalRedirectResponse struct {
@ -385,7 +392,10 @@ type apiStripeSubscriptionUpdatedEvent struct {
Items *struct {
Data []*struct {
Price *struct {
ID string `json:"id"`
ID string `json:"id"`
Recurring *struct {
Interval string `json:"interval"`
} `json:"recurring"`
} `json:"price"`
} `json:"data"`
} `json:"items"`

View file

@ -46,7 +46,8 @@ var (
// Manager-related queries
const (
createTablesQueriesNoTx = `
createTablesQueries = `
BEGIN;
CREATE TABLE IF NOT EXISTS tier (
id TEXT PRIMARY KEY,
code TEXT NOT NULL,
@ -59,10 +60,12 @@ const (
attachment_total_size_limit INT NOT NULL,
attachment_expiry_duration INT NOT NULL,
attachment_bandwidth_limit INT NOT NULL,
stripe_price_id TEXT
stripe_monthly_price_id TEXT,
stripe_yearly_price_id TEXT
);
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
CREATE TABLE IF NOT EXISTS user (
id TEXT PRIMARY KEY,
tier_id TEXT,
@ -76,6 +79,7 @@ const (
stripe_customer_id TEXT,
stripe_subscription_id TEXT,
stripe_subscription_status TEXT,
stripe_subscription_interval TEXT,
stripe_subscription_paid_until INT,
stripe_subscription_cancel_at INT,
created INT NOT NULL,
@ -112,33 +116,33 @@ const (
INSERT INTO user (id, user, pass, role, sync_topic, created)
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', UNIXEPOCH())
ON CONFLICT (id) DO NOTHING;
COMMIT;
`
createTablesQueries = `BEGIN; ` + createTablesQueriesNoTx + ` COMMIT;`
builtinStartupQueries = `
PRAGMA foreign_keys = ON;
`
selectUserByIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE u.id = ?
`
selectUserByNameQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE user = ?
`
selectUserByTokenQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u
JOIN user_token tk on u.id = tk.user_id
LEFT JOIN tier t on t.id = u.tier_id
WHERE tk.token = ? AND (tk.expires = 0 OR tk.expires >= ?)
`
selectUserByStripeCustomerIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE u.stripe_customer_id = ?
@ -246,27 +250,27 @@ const (
`
insertTierQuery = `
INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
updateTierQuery = `
UPDATE tier
SET name = ?, messages_limit = ?, messages_expiry_duration = ?, emails_limit = ?, reservations_limit = ?, attachment_file_size_limit = ?, attachment_total_size_limit = ?, attachment_expiry_duration = ?, attachment_bandwidth_limit = ?, stripe_price_id = ?
SET name = ?, messages_limit = ?, messages_expiry_duration = ?, emails_limit = ?, reservations_limit = ?, attachment_file_size_limit = ?, attachment_total_size_limit = ?, attachment_expiry_duration = ?, attachment_bandwidth_limit = ?, stripe_monthly_price_id = ?, stripe_yearly_price_id = ?
WHERE code = ?
`
selectTiersQuery = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
FROM tier
`
selectTierByCodeQuery = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
FROM tier
WHERE code = ?
`
selectTierByPriceIDQuery = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
FROM tier
WHERE stripe_price_id = ?
WHERE (stripe_monthly_price_id = ? OR stripe_yearly_price_id = ?)
`
updateUserTierQuery = `UPDATE user SET tier_id = (SELECT id FROM tier WHERE code = ?) WHERE user = ?`
deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
@ -274,21 +278,86 @@ const (
updateBillingQuery = `
UPDATE user
SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_paid_until = ?, stripe_subscription_cancel_at = ?
SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_interval = ?, stripe_subscription_paid_until = ?, stripe_subscription_cancel_at = ?
WHERE user = ?
`
)
// Schema management queries
const (
currentSchemaVersion = 2
currentSchemaVersion = 3
insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
// 1 -> 2 (complex migration!)
migrate1To2RenameUserTableQueryNoTx = `
migrate1To2CreateTablesQueries = `
ALTER TABLE user RENAME TO user_old;
CREATE TABLE IF NOT EXISTS tier (
id TEXT PRIMARY KEY,
code TEXT NOT NULL,
name TEXT NOT NULL,
messages_limit INT NOT NULL,
messages_expiry_duration INT NOT NULL,
emails_limit INT NOT NULL,
reservations_limit INT NOT NULL,
attachment_file_size_limit INT NOT NULL,
attachment_total_size_limit INT NOT NULL,
attachment_expiry_duration INT NOT NULL,
attachment_bandwidth_limit INT NOT NULL,
stripe_price_id TEXT
);
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
CREATE TABLE IF NOT EXISTS user (
id TEXT PRIMARY KEY,
tier_id TEXT,
user TEXT NOT NULL,
pass TEXT NOT NULL,
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
prefs JSON NOT NULL DEFAULT '{}',
sync_topic TEXT NOT NULL,
stats_messages INT NOT NULL DEFAULT (0),
stats_emails INT NOT NULL DEFAULT (0),
stripe_customer_id TEXT,
stripe_subscription_id TEXT,
stripe_subscription_status TEXT,
stripe_subscription_paid_until INT,
stripe_subscription_cancel_at INT,
created INT NOT NULL,
deleted INT,
FOREIGN KEY (tier_id) REFERENCES tier (id)
);
CREATE UNIQUE INDEX idx_user ON user (user);
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
CREATE TABLE IF NOT EXISTS user_access (
user_id TEXT NOT NULL,
topic TEXT NOT NULL,
read INT NOT NULL,
write INT NOT NULL,
owner_user_id INT,
PRIMARY KEY (user_id, topic),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS user_token (
user_id TEXT NOT NULL,
token TEXT NOT NULL,
label TEXT NOT NULL,
last_access INT NOT NULL,
last_origin TEXT NOT NULL,
expires INT NOT NULL,
PRIMARY KEY (user_id, token),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO user (id, user, pass, role, sync_topic, created)
VALUES ('u_everyone', '*', '', 'anonymous', '', UNIXEPOCH())
ON CONFLICT (id) DO NOTHING;
`
migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old`
migrate1To2InsertUserNoTx = `
@ -304,11 +373,22 @@ const (
DROP TABLE access;
DROP TABLE user_old;
`
// 2 -> 3
migrate2To3UpdateQueries = `
ALTER TABLE user ADD COLUMN stripe_subscription_interval TEXT;
ALTER TABLE tier RENAME COLUMN stripe_price_id TO stripe_monthly_price_id;
ALTER TABLE tier ADD COLUMN stripe_yearly_price_id TEXT;
DROP INDEX IF EXISTS idx_tier_price_id;
CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
`
)
var (
migrations = map[int]func(db *sql.DB) error{
1: migrateFrom1,
2: migrateFrom2,
}
)
@ -805,13 +885,13 @@ func (a *Manager) userByToken(token string) (*User, error) {
func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
defer rows.Close()
var id, username, hash, role, prefs, syncTopic string
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierID, tierCode, tierName sql.NullString
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripeSubscriptionInterval, stripeMonthlyPriceID, stripeYearlyPriceID, tierID, tierCode, tierName sql.NullString
var messages, emails int64
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
if !rows.Next() {
return nil, ErrUserNotFound
}
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil {
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
@ -828,11 +908,12 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
Emails: emails,
},
Billing: &Billing{
StripeCustomerID: stripeCustomerID.String, // May be empty
StripeSubscriptionID: stripeSubscriptionID.String, // May be empty
StripeSubscriptionStatus: stripe.SubscriptionStatus(stripeSubscriptionStatus.String), // May be empty
StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), // May be zero
StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0), // May be zero
StripeCustomerID: stripeCustomerID.String, // May be empty
StripeSubscriptionID: stripeSubscriptionID.String, // May be empty
StripeSubscriptionStatus: stripe.SubscriptionStatus(stripeSubscriptionStatus.String), // May be empty
StripeSubscriptionInterval: stripe.PriceRecurringInterval(stripeSubscriptionInterval.String), // May be empty
StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), // May be zero
StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0), // May be zero
},
Deleted: deleted.Valid,
}
@ -853,7 +934,8 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
StripePriceID: stripePriceID.String, // May be empty
StripeMonthlyPriceID: stripeMonthlyPriceID.String, // May be empty
StripeYearlyPriceID: stripeYearlyPriceID.String, // May be empty
}
}
return user, nil
@ -1134,7 +1216,7 @@ func (a *Manager) AddTier(tier *Tier) error {
if tier.ID == "" {
tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength)
}
if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripePriceID)); err != nil {
if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID)); err != nil {
return err
}
return nil
@ -1142,7 +1224,7 @@ func (a *Manager) AddTier(tier *Tier) error {
// UpdateTier updates a tier's properties in the database
func (a *Manager) UpdateTier(tier *Tier) error {
if _, err := a.db.Exec(updateTierQuery, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripePriceID), tier.Code); err != nil {
if _, err := a.db.Exec(updateTierQuery, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID), tier.Code); err != nil {
return err
}
return nil
@ -1162,7 +1244,7 @@ func (a *Manager) RemoveTier(code string) error {
// ChangeBilling updates a user's billing fields, namely the Stripe customer ID, and subscription information
func (a *Manager) ChangeBilling(username string, billing *Billing) error {
if _, err := a.db.Exec(updateBillingQuery, nullString(billing.StripeCustomerID), nullString(billing.StripeSubscriptionID), nullString(string(billing.StripeSubscriptionStatus)), nullInt64(billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(billing.StripeSubscriptionCancelAt.Unix()), username); err != nil {
if _, err := a.db.Exec(updateBillingQuery, nullString(billing.StripeCustomerID), nullString(billing.StripeSubscriptionID), nullString(string(billing.StripeSubscriptionStatus)), nullString(string(billing.StripeSubscriptionInterval)), nullInt64(billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(billing.StripeSubscriptionCancelAt.Unix()), username); err != nil {
return err
}
return nil
@ -1200,7 +1282,7 @@ func (a *Manager) Tier(code string) (*Tier, error) {
// TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist
func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
rows, err := a.db.Query(selectTierByPriceIDQuery, priceID)
rows, err := a.db.Query(selectTierByPriceIDQuery, priceID, priceID)
if err != nil {
return nil, err
}
@ -1210,12 +1292,12 @@ func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
var id, code, name string
var stripePriceID sql.NullString
var stripeMonthlyPriceID, stripeYearlyPriceID sql.NullString
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit sql.NullInt64
if !rows.Next() {
return nil, ErrTierNotFound
}
if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil {
if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
@ -1233,7 +1315,8 @@ func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
StripePriceID: stripePriceID.String, // May be empty
StripeMonthlyPriceID: stripeMonthlyPriceID.String, // May be empty
StripeYearlyPriceID: stripeYearlyPriceID.String, // May be empty
}, nil
}
@ -1313,10 +1396,7 @@ func migrateFrom1(db *sql.DB) error {
}
defer tx.Rollback()
// Rename user -> user_old, and create new tables
if _, err := tx.Exec(migrate1To2RenameUserTableQueryNoTx); err != nil {
return err
}
if _, err := tx.Exec(createTablesQueriesNoTx); err != nil {
if _, err := tx.Exec(migrate1To2CreateTablesQueries); err != nil {
return err
}
// Insert users from user_old into new user table, with ID and sync_topic
@ -1356,6 +1436,22 @@ func migrateFrom1(db *sql.DB) error {
return nil
}
func migrateFrom2(db *sql.DB) error {
log.Tag(tag).Info("Migrating user database schema: from 2 to 3")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate2To3UpdateQueries); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 3); err != nil {
return err
}
return tx.Commit()
}
func nullString(s string) sql.NullString {
if s == "" {
return sql.NullString{}

View file

@ -4,6 +4,7 @@ import (
"database/sql"
"fmt"
"github.com/stretchr/testify/require"
"github.com/stripe/stripe-go/v74"
"golang.org/x/crypto/bcrypt"
"heckel.io/ntfy/util"
"net/netip"
@ -113,7 +114,8 @@ func TestManager_AddUser_And_Query(t *testing.T) {
require.Nil(t, a.ChangeBilling("user", &Billing{
StripeCustomerID: "acct_123",
StripeSubscriptionID: "sub_123",
StripeSubscriptionStatus: "active",
StripeSubscriptionStatus: stripe.SubscriptionStatusActive,
StripeSubscriptionInterval: stripe.PriceRecurringIntervalMonth,
StripeSubscriptionPaidUntil: time.Now().Add(time.Hour),
StripeSubscriptionCancelAt: time.Unix(0, 0),
}))
@ -395,7 +397,7 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
require.Nil(t, a.AddTier(&Tier{
Code: "pro",
Name: "ntfy Pro",
StripePriceID: "price123",
StripeMonthlyPriceID: "price123",
MessageLimit: 5_000,
MessageExpiryDuration: 3 * 24 * time.Hour,
EmailLimit: 50,
@ -761,7 +763,7 @@ func TestManager_Tier_Create_Update_List_Delete(t *testing.T) {
AttachmentTotalSizeLimit: 1,
AttachmentExpiryDuration: time.Second,
AttachmentBandwidthLimit: 1,
StripePriceID: "price_1",
StripeMonthlyPriceID: "price_1",
}))
require.Nil(t, a.AddTier(&Tier{
Code: "pro",
@ -774,7 +776,7 @@ func TestManager_Tier_Create_Update_List_Delete(t *testing.T) {
AttachmentTotalSizeLimit: 123123,
AttachmentExpiryDuration: 10800 * time.Second,
AttachmentBandwidthLimit: 21474836480,
StripePriceID: "price_2",
StripeMonthlyPriceID: "price_2",
}))
require.Nil(t, a.AddUser("phil", "phil", RoleUser))
require.Nil(t, a.ChangeTier("phil", "pro"))
@ -800,7 +802,7 @@ func TestManager_Tier_Create_Update_List_Delete(t *testing.T) {
require.Equal(t, int64(123123), ti.AttachmentTotalSizeLimit)
require.Equal(t, 10800*time.Second, ti.AttachmentExpiryDuration)
require.Equal(t, int64(21474836480), ti.AttachmentBandwidthLimit)
require.Equal(t, "price_2", ti.StripePriceID)
require.Equal(t, "price_2", ti.StripeMonthlyPriceID)
// Update tier
ti.EmailLimit = 999999
@ -822,7 +824,7 @@ func TestManager_Tier_Create_Update_List_Delete(t *testing.T) {
require.Equal(t, int64(1), ti.AttachmentTotalSizeLimit)
require.Equal(t, time.Second, ti.AttachmentExpiryDuration)
require.Equal(t, int64(1), ti.AttachmentBandwidthLimit)
require.Equal(t, "price_1", ti.StripePriceID)
require.Equal(t, "price_1", ti.StripeMonthlyPriceID)
ti = tiers[1]
require.Equal(t, "pro", ti.Code)
@ -835,7 +837,7 @@ func TestManager_Tier_Create_Update_List_Delete(t *testing.T) {
require.Equal(t, int64(123123), ti.AttachmentTotalSizeLimit)
require.Equal(t, 10800*time.Second, ti.AttachmentExpiryDuration)
require.Equal(t, int64(21474836480), ti.AttachmentBandwidthLimit)
require.Equal(t, "price_2", ti.StripePriceID)
require.Equal(t, "price_2", ti.StripeMonthlyPriceID)
ti, err = a.TierByStripePrice("price_1")
require.Nil(t, err)
@ -849,7 +851,7 @@ func TestManager_Tier_Create_Update_List_Delete(t *testing.T) {
require.Equal(t, int64(1), ti.AttachmentTotalSizeLimit)
require.Equal(t, time.Second, ti.AttachmentExpiryDuration)
require.Equal(t, int64(1), ti.AttachmentBandwidthLimit)
require.Equal(t, "price_1", ti.StripePriceID)
require.Equal(t, "price_1", ti.StripeMonthlyPriceID)
// Cannot remove tier, since user has this tier
require.Error(t, a.RemoveTier("pro"))

View file

@ -91,15 +91,17 @@ type Tier struct {
AttachmentTotalSizeLimit int64 // Total file size for all files of this user (bytes)
AttachmentExpiryDuration time.Duration // Duration after which attachments will be deleted
AttachmentBandwidthLimit int64 // Daily bandwidth limit for the user
StripePriceID string // Price ID for paid tiers (price_...)
StripeMonthlyPriceID string // Monthly price ID for paid tiers (price_...)
StripeYearlyPriceID string // Yearly price ID for paid tiers (price_...)
}
// Context returns fields for the log
func (t *Tier) Context() log.Context {
return log.Context{
"tier_id": t.ID,
"tier_code": t.Code,
"stripe_price_id": t.StripePriceID,
"tier_id": t.ID,
"tier_code": t.Code,
"stripe_monthly_price_id": t.StripeMonthlyPriceID,
"stripe_yearly_price_id": t.StripeYearlyPriceID,
}
}
@ -136,6 +138,7 @@ type Billing struct {
StripeCustomerID string
StripeSubscriptionID string
StripeSubscriptionStatus stripe.SubscriptionStatus
StripeSubscriptionInterval stripe.PriceRecurringInterval
StripeSubscriptionPaidUntil time.Time
StripeSubscriptionCancelAt time.Time
}

View file

@ -49,12 +49,15 @@ func TestAllowedTier(t *testing.T) {
func TestTierContext(t *testing.T) {
tier := &Tier{
ID: "ti_abc",
Code: "pro",
StripePriceID: "price_123",
ID: "ti_abc",
Code: "pro",
StripeMonthlyPriceID: "price_123",
StripeYearlyPriceID: "price_456",
}
context := tier.Context()
require.Equal(t, "ti_abc", context["tier_id"])
require.Equal(t, "pro", context["tier_code"])
require.Equal(t, "price_123", context["stripe_price_id"])
require.Equal(t, "price_123", context["stripe_monthly_price_id"])
require.Equal(t, "price_456", context["stripe_yearly_price_id"])
}

View file

@ -193,6 +193,8 @@
"account_basics_tier_admin_suffix_no_tier": "(no tier)",
"account_basics_tier_basic": "Basic",
"account_basics_tier_free": "Free",
"account_basics_tier_interval_monthly": "monthly",
"account_basics_tier_interval_yearly": "annually",
"account_basics_tier_upgrade_button": "Upgrade to Pro",
"account_basics_tier_change_button": "Change",
"account_basics_tier_paid_until": "Subscription paid until {{date}}, and will auto-renew",
@ -215,15 +217,21 @@
"account_delete_dialog_button_submit": "Permanently delete account",
"account_delete_dialog_billing_warning": "Deleting your account also cancels your billing subscription immediately. You will not have access to the billing dashboard anymore.",
"account_upgrade_dialog_title": "Change account tier",
"account_upgrade_dialog_interval_monthly": "Monthly",
"account_upgrade_dialog_interval_yearly": "Annually",
"account_upgrade_dialog_cancel_warning": "This will <strong>cancel your subscription</strong>, and downgrade your account on {{date}}. On that date, topic reservations as well as messages cached on the server <strong>will be deleted</strong>.",
"account_upgrade_dialog_proration_info": "<strong>Proration</strong>: When switching between paid plans, the price difference will be charged or refunded in the next invoice. You will not receive another invoice until the end of the next billing period.",
"account_upgrade_dialog_proration_info": "<strong>Proration</strong>: When upgrading between paid plans, the price difference will be <strong>charged immediately</strong>. When downgrading to a lower tier, the balance will be used to pay for future billing periods.",
"account_upgrade_dialog_reservations_warning_one": "The selected tier allows fewer reserved topics than your current tier. Before changing your tier, <strong>please delete at least one reservation</strong>. You can remove reservations in the <Link>Settings</Link>.",
"account_upgrade_dialog_reservations_warning_other": "The selected tier allows fewer reserved topics than your current tier. Before changing your tier, <strong>please delete at least {{count}} reservations</strong>. You can remove reservations in the <Link>Settings</Link>.",
"account_upgrade_dialog_tier_features_reservations": "{{reservations}} reserved topics",
"account_upgrade_dialog_tier_features_no_reservations": "No reserved topics",
"account_upgrade_dialog_tier_features_messages": "{{messages}} daily messages",
"account_upgrade_dialog_tier_features_emails": "{{emails}} daily emails",
"account_upgrade_dialog_tier_features_attachment_file_size": "{{filesize}} per file",
"account_upgrade_dialog_tier_features_attachment_total_size": "{{totalsize}} total storage",
"account_upgrade_dialog_tier_price_per_month": "month",
"account_upgrade_dialog_tier_price_billed_monthly": "{{price}} per year. Billed monthly.",
"account_upgrade_dialog_tier_price_billed_yearly": "{{price}} billed annually. Save {{save}}.",
"account_upgrade_dialog_tier_selected_label": "Selected",
"account_upgrade_dialog_tier_current_label": "Current",
"account_upgrade_dialog_button_cancel": "Cancel",

View file

@ -257,23 +257,24 @@ class AccountApi {
return this.tiers;
}
async createBillingSubscription(tier) {
console.log(`[AccountApi] Creating billing subscription with ${tier}`);
return await this.upsertBillingSubscription("POST", tier)
async createBillingSubscription(tier, interval) {
console.log(`[AccountApi] Creating billing subscription with ${tier} and interval ${interval}`);
return await this.upsertBillingSubscription("POST", tier, interval)
}
async updateBillingSubscription(tier) {
console.log(`[AccountApi] Updating billing subscription with ${tier}`);
return await this.upsertBillingSubscription("PUT", tier)
async updateBillingSubscription(tier, interval) {
console.log(`[AccountApi] Updating billing subscription with ${tier} and interval ${interval}`);
return await this.upsertBillingSubscription("PUT", tier, interval)
}
async upsertBillingSubscription(method, tier) {
async upsertBillingSubscription(method, tier, interval) {
const url = accountBillingSubscriptionUrl(config.base_url);
const response = await fetchOrThrow(url, {
method: method,
headers: withBearerAuth({}, session.token()),
body: JSON.stringify({
tier: tier
tier: tier,
interval: interval
})
});
return await response.json(); // May throw SyntaxError
@ -371,6 +372,12 @@ export const SubscriptionStatus = {
PAST_DUE: "past_due"
};
// Maps to stripe.PriceRecurringInterval
export const SubscriptionInterval = {
MONTH: "month",
YEAR: "year"
};
// Maps to user.Permission in user/types.go
export const Permission = {
READ_WRITE: "read-write",

View file

@ -212,6 +212,13 @@ export const formatNumber = (n) => {
return n;
}
export const formatPrice = (n) => {
if (n % 100 === 0) {
return `$${n/100}`;
}
return `$${(n/100).toPrecision(2)}`;
}
export const openUrl = (url) => {
window.open(url, "_blank", "noopener,noreferrer");
};

View file

@ -35,7 +35,7 @@ import TextField from "@mui/material/TextField";
import routes from "./routes";
import IconButton from "@mui/material/IconButton";
import {formatBytes, formatShortDate, formatShortDateTime, openUrl} from "../app/utils";
import accountApi, {LimitBasis, Role, SubscriptionStatus} from "../app/AccountApi";
import accountApi, {LimitBasis, Role, SubscriptionInterval, SubscriptionStatus} from "../app/AccountApi";
import InfoOutlinedIcon from '@mui/icons-material/InfoOutlined';
import {Pref, PrefGroup} from "./Pref";
import db from "../app/db";
@ -248,6 +248,11 @@ const AccountType = () => {
accountType = (config.enable_payments) ? t("account_basics_tier_free") : t("account_basics_tier_basic");
} else {
accountType = account.tier.name;
if (account.billing?.interval === SubscriptionInterval.MONTH) {
accountType += ` (${t("account_basics_tier_interval_monthly")})`;
} else if (account.billing?.interval === SubscriptionInterval.YEAR) {
accountType += ` (${t("account_basics_tier_interval_yearly")})`;
}
}
return (

View file

@ -3,20 +3,20 @@ import {useContext, useEffect, useState} from 'react';
import Dialog from '@mui/material/Dialog';
import DialogContent from '@mui/material/DialogContent';
import DialogTitle from '@mui/material/DialogTitle';
import {Alert, CardActionArea, CardContent, ListItem, useMediaQuery} from "@mui/material";
import {Alert, Badge, CardActionArea, CardContent, Chip, ListItem, Stack, Switch, useMediaQuery} from "@mui/material";
import theme from "./theme";
import DialogFooter from "./DialogFooter";
import Button from "@mui/material/Button";
import accountApi from "../app/AccountApi";
import accountApi, {SubscriptionInterval} from "../app/AccountApi";
import session from "../app/Session";
import routes from "./routes";
import Card from "@mui/material/Card";
import Typography from "@mui/material/Typography";
import {AccountContext} from "./App";
import {formatBytes, formatNumber, formatShortDate} from "../app/utils";
import {formatBytes, formatNumber, formatPrice, formatShortDate} from "../app/utils";
import {Trans, useTranslation} from "react-i18next";
import List from "@mui/material/List";
import {Check} from "@mui/icons-material";
import {Check, Close} from "@mui/icons-material";
import ListItemIcon from "@mui/material/ListItemIcon";
import ListItemText from "@mui/material/ListItemText";
import Box from "@mui/material/Box";
@ -28,6 +28,7 @@ const UpgradeDialog = (props) => {
const { account } = useContext(AccountContext); // May be undefined!
const [error, setError] = useState("");
const [tiers, setTiers] = useState(null);
const [interval, setInterval] = useState(account?.billing?.interval || SubscriptionInterval.YEAR);
const [newTierCode, setNewTierCode] = useState(account?.tier?.code); // May be undefined
const [loading, setLoading] = useState(false);
const fullScreen = useMediaQuery(theme.breakpoints.down('sm'));
@ -46,6 +47,7 @@ const UpgradeDialog = (props) => {
const tiersMap = Object.assign(...tiers.map(tier => ({[tier.code]: tier})));
const newTier = tiersMap[newTierCode]; // May be undefined
const currentTier = account?.tier; // May be undefined
const currentInterval = account?.billing?.interval; // May be undefined
const currentTierCode = currentTier?.code; // May be undefined
// Figure out buttons, labels and the submit action
@ -54,7 +56,7 @@ const UpgradeDialog = (props) => {
submitButtonLabel = t("account_upgrade_dialog_button_redirect_signup");
submitAction = Action.REDIRECT_SIGNUP;
banner = null;
} else if (currentTierCode === newTierCode) {
} else if (currentTierCode === newTierCode && currentInterval === interval) {
submitButtonLabel = t("account_upgrade_dialog_button_update_subscription");
submitAction = null;
banner = (currentTierCode) ? Banner.PRORATION_INFO : null;
@ -88,10 +90,10 @@ const UpgradeDialog = (props) => {
try {
setLoading(true);
if (submitAction === Action.CREATE_SUBSCRIPTION) {
const response = await accountApi.createBillingSubscription(newTierCode);
const response = await accountApi.createBillingSubscription(newTierCode, interval);
window.location.href = response.redirect_url;
} else if (submitAction === Action.UPDATE_SUBSCRIPTION) {
await accountApi.updateBillingSubscription(newTierCode);
await accountApi.updateBillingSubscription(newTierCode, interval);
} else if (submitAction === Action.CANCEL_SUBSCRIPTION) {
await accountApi.deleteBillingSubscription();
}
@ -108,15 +110,45 @@ const UpgradeDialog = (props) => {
}
}
// Figure out discount
let discount;
if (newTier?.prices) {
discount = Math.round(((newTier.prices.month*12/newTier.prices.year)-1)*100);
} else {
for (const t of tiers) {
if (t.prices) {
discount = Math.round(((t.prices.month*12/t.prices.year)-1)*100);
break;
}
}
}
return (
<Dialog
open={props.open}
onClose={props.onCancel}
maxWidth="md"
fullWidth
maxWidth="lg"
fullScreen={fullScreen}
>
<DialogTitle>{t("account_upgrade_dialog_title")}</DialogTitle>
<DialogTitle>
<div style={{ display: "flex", flexDirection: "row" }}>
<div style={{ flexGrow: 1 }}>{t("account_upgrade_dialog_title")}</div>
<div style={{
display: "flex",
flexDirection: "row",
alignItems: "center",
marginTop: "4px"
}}>
<Typography component="span" variant="subtitle1">{t("account_upgrade_dialog_interval_monthly")}</Typography>
<Switch
checked={interval === SubscriptionInterval.YEAR}
onChange={(ev) => setInterval(ev.target.checked ? SubscriptionInterval.YEAR : SubscriptionInterval.MONTH)}
/>
<Typography component="span" variant="subtitle1">{t("account_upgrade_dialog_interval_yearly")}</Typography>
{discount > 0 && <Chip label={`-${discount}%`} color="primary" size="small" sx={{ marginLeft: "5px" }}/>}
</div>
</div>
</DialogTitle>
<DialogContent>
<div style={{
display: "flex",
@ -130,24 +162,25 @@ const UpgradeDialog = (props) => {
tier={tier}
current={currentTierCode === tier.code} // tier.code or currentTierCode may be undefined!
selected={newTierCode === tier.code} // tier.code may be undefined!
interval={interval}
onClick={() => setNewTierCode(tier.code)} // tier.code may be undefined!
/>
)}
</div>
{banner === Banner.CANCEL_WARNING &&
<Alert severity="warning">
<Alert severity="warning" sx={{ fontSize: "1rem" }}>
<Trans
i18nKey="account_upgrade_dialog_cancel_warning"
values={{ date: formatShortDate(account?.billing?.paid_until || 0) }} />
</Alert>
}
{banner === Banner.PRORATION_INFO &&
<Alert severity="info">
<Alert severity="info" sx={{ fontSize: "1rem" }}>
<Trans i18nKey="account_upgrade_dialog_proration_info" />
</Alert>
}
{banner === Banner.RESERVATIONS_WARNING &&
<Alert severity="warning">
<Alert severity="warning" sx={{ fontSize: "1rem" }}>
<Trans
i18nKey="account_upgrade_dialog_reservations_warning"
count={account?.reservations.length - newTier?.limits.reservations}
@ -169,28 +202,37 @@ const UpgradeDialog = (props) => {
const TierCard = (props) => {
const { t } = useTranslation();
const tier = props.tier;
let cardStyle, labelStyle, labelText;
if (props.selected) {
cardStyle = { background: "#eee", border: "2px solid #338574" };
cardStyle = { background: "#eee", border: "3px solid #338574" };
labelStyle = { background: "#338574", color: "white" };
labelText = t("account_upgrade_dialog_tier_selected_label");
} else if (props.current) {
cardStyle = { border: "2px solid #eee" };
cardStyle = { border: "3px solid #eee" };
labelStyle = { background: "#eee", color: "black" };
labelText = t("account_upgrade_dialog_tier_current_label");
} else {
cardStyle = { border: "2px solid transparent" };
cardStyle = { border: "3px solid transparent" };
}
let monthlyPrice;
if (!tier.prices) {
monthlyPrice = 0;
} else if (props.interval === SubscriptionInterval.YEAR) {
monthlyPrice = tier.prices.year/12;
} else if (props.interval === SubscriptionInterval.MONTH) {
monthlyPrice = tier.prices.month;
}
return (
<Box sx={{
m: "7px",
minWidth: "190px",
maxWidth: "250px",
minWidth: "240px",
flexGrow: 1,
flexShrink: 1,
flexBasis: 0,
borderRadius: "3px",
borderRadius: "5px",
"&:first-of-type": { ml: 0 },
"&:last-of-type": { mr: 0 },
...cardStyle
@ -208,19 +250,29 @@ const TierCard = (props) => {
...labelStyle
}}>{labelText}</div>
}
<Typography variant="h5" component="div">
<Typography variant="subtitle1" component="div">
{tier.name || t("account_basics_tier_free")}
</Typography>
<div>
<Typography component="span" variant="h4" sx={{ fontWeight: 500, marginRight: "3px" }}>{formatPrice(monthlyPrice)}</Typography>
{monthlyPrice > 0 && <>/ {t("account_upgrade_dialog_tier_price_per_month")}</>}
</div>
<List dense>
{tier.limits.reservations > 0 && <FeatureItem>{t("account_upgrade_dialog_tier_features_reservations", { reservations: tier.limits.reservations })}</FeatureItem>}
<FeatureItem>{t("account_upgrade_dialog_tier_features_messages", { messages: formatNumber(tier.limits.messages) })}</FeatureItem>
<FeatureItem>{t("account_upgrade_dialog_tier_features_emails", { emails: formatNumber(tier.limits.emails) })}</FeatureItem>
<FeatureItem>{t("account_upgrade_dialog_tier_features_attachment_file_size", { filesize: formatBytes(tier.limits.attachment_file_size, 0) })}</FeatureItem>
<FeatureItem>{t("account_upgrade_dialog_tier_features_attachment_total_size", { totalsize: formatBytes(tier.limits.attachment_total_size, 0) })}</FeatureItem>
{tier.limits.reservations > 0 && <Feature>{t("account_upgrade_dialog_tier_features_reservations", { reservations: tier.limits.reservations })}</Feature>}
{tier.limits.reservations === 0 && <NoFeature>{t("account_upgrade_dialog_tier_features_no_reservations")}</NoFeature>}
<Feature>{t("account_upgrade_dialog_tier_features_messages", { messages: formatNumber(tier.limits.messages) })}</Feature>
<Feature>{t("account_upgrade_dialog_tier_features_emails", { emails: formatNumber(tier.limits.emails) })}</Feature>
<Feature>{t("account_upgrade_dialog_tier_features_attachment_file_size", { filesize: formatBytes(tier.limits.attachment_file_size, 0) })}</Feature>
<Feature>{t("account_upgrade_dialog_tier_features_attachment_total_size", { totalsize: formatBytes(tier.limits.attachment_total_size, 0) })}</Feature>
</List>
{tier.price &&
<Typography variant="subtitle1" sx={{fontWeight: 500}}>
{tier.price} / month
{tier.prices && props.interval === SubscriptionInterval.MONTH &&
<Typography variant="body2" color="gray">
{t("account_upgrade_dialog_tier_price_billed_monthly", { price: formatPrice(tier.prices.month*12) })}
</Typography>
}
{tier.prices && props.interval === SubscriptionInterval.YEAR &&
<Typography variant="body2" color="gray">
{t("account_upgrade_dialog_tier_price_billed_yearly", { price: formatPrice(tier.prices.year), save: formatPrice(tier.prices.month*12-tier.prices.year) })}
</Typography>
}
</CardContent>
@ -231,16 +283,25 @@ const TierCard = (props) => {
);
}
const Feature = (props) => {
return <FeatureItem feature={true}>{props.children}</FeatureItem>;
}
const NoFeature = (props) => {
return <FeatureItem feature={false}>{props.children}</FeatureItem>;
}
const FeatureItem = (props) => {
return (
<ListItem disableGutters sx={{m: 0, p: 0}}>
<ListItemIcon sx={{minWidth: "24px"}}>
<Check fontSize="small" sx={{ color: "#338574" }}/>
{props.feature && <Check fontSize="small" sx={{ color: "#338574" }}/>}
{!props.feature && <Close fontSize="small" sx={{ color: "gray" }}/>}
</ListItemIcon>
<ListItemText
sx={{mt: "2px", mb: "2px"}}
primary={
<Typography variant="body2">
<Typography variant="body1">
{props.children}
</Typography>
}