Support for annual billing intervals

This commit is contained in:
binwiederhier 2023-02-21 22:44:30 -05:00
parent 07afaf961d
commit ef9d6d9f6c
17 changed files with 453 additions and 183 deletions

View File

@ -54,7 +54,8 @@ var cmdTier = &cli.Command{
&cli.StringFlag{Name: "attachment-total-size-limit", Value: defaultAttachmentTotalSizeLimit, Usage: "total size limit of attachments for the user"}, &cli.StringFlag{Name: "attachment-total-size-limit", Value: defaultAttachmentTotalSizeLimit, Usage: "total size limit of attachments for the user"},
&cli.DurationFlag{Name: "attachment-expiry-duration", Value: defaultAttachmentExpiryDuration, Usage: "duration after which attachments are deleted"}, &cli.DurationFlag{Name: "attachment-expiry-duration", Value: defaultAttachmentExpiryDuration, Usage: "duration after which attachments are deleted"},
&cli.StringFlag{Name: "attachment-bandwidth-limit", Value: defaultAttachmentBandwidthLimit, Usage: "daily bandwidth limit for attachment uploads/downloads"}, &cli.StringFlag{Name: "attachment-bandwidth-limit", Value: defaultAttachmentBandwidthLimit, Usage: "daily bandwidth limit for attachment uploads/downloads"},
&cli.StringFlag{Name: "stripe-price-id", Usage: "Stripe price ID for paid tiers (e.g. price_12345)"}, &cli.StringFlag{Name: "stripe-monthly-price-id", Usage: "Monthly Stripe price ID for paid tiers (e.g. price_12345)"},
&cli.StringFlag{Name: "stripe-yearly-price-id", Usage: "Yearly Stripe price ID for paid tiers (e.g. price_12345)"},
&cli.BoolFlag{Name: "ignore-exists", Usage: "if the tier already exists, perform no action and exit"}, &cli.BoolFlag{Name: "ignore-exists", Usage: "if the tier already exists, perform no action and exit"},
}, },
Description: `Add a new tier to the ntfy user database. Description: `Add a new tier to the ntfy user database.
@ -96,7 +97,8 @@ Examples:
&cli.StringFlag{Name: "attachment-total-size-limit", Usage: "total size limit of attachments for the user"}, &cli.StringFlag{Name: "attachment-total-size-limit", Usage: "total size limit of attachments for the user"},
&cli.DurationFlag{Name: "attachment-expiry-duration", Usage: "duration after which attachments are deleted"}, &cli.DurationFlag{Name: "attachment-expiry-duration", Usage: "duration after which attachments are deleted"},
&cli.StringFlag{Name: "attachment-bandwidth-limit", Usage: "daily bandwidth limit for attachment uploads/downloads"}, &cli.StringFlag{Name: "attachment-bandwidth-limit", Usage: "daily bandwidth limit for attachment uploads/downloads"},
&cli.StringFlag{Name: "stripe-price-id", Usage: "Stripe price ID for paid tiers (e.g. price_12345)"}, &cli.StringFlag{Name: "stripe-monthly-price-id", Usage: "Monthly Stripe price ID for paid tiers (e.g. price_12345)"},
&cli.StringFlag{Name: "stripe-yearly-price-id", Usage: "Yearly Stripe price ID for paid tiers (e.g. price_12345)"},
}, },
Description: `Updates a tier to change the limits. Description: `Updates a tier to change the limits.
@ -110,7 +112,8 @@ Examples:
ntfy tier change --name="Pro" pro # Update the name of an existing tier ntfy tier change --name="Pro" pro # Update the name of an existing tier
ntfy tier change \ # Update multiple limits and fields ntfy tier change \ # Update multiple limits and fields
--message-expiry-duration=24h \ --message-expiry-duration=24h \
--stripe-price-id=price_1234 \ --stripe-monthly-price-id=price_1234 \
--stripe-monthly-price-id=price_5678 \
pro pro
`, `,
}, },
@ -166,6 +169,10 @@ func execTierAdd(c *cli.Context) error {
return errors.New("tier code expected, type 'ntfy tier add --help' for help") return errors.New("tier code expected, type 'ntfy tier add --help' for help")
} else if !user.AllowedTier(code) { } else if !user.AllowedTier(code) {
return errors.New("tier code must consist only of numbers and letters") return errors.New("tier code must consist only of numbers and letters")
} else if c.String("stripe-monthly-price-id") != "" && c.String("stripe-yearly-price-id") == "" {
return errors.New("if stripe-monthly-price-id is set, stripe-yearly-price-id must also be set")
} else if c.String("stripe-monthly-price-id") == "" && c.String("stripe-yearly-price-id") != "" {
return errors.New("if stripe-yearly-price-id is set, stripe-monthly-price-id must also be set")
} }
manager, err := createUserManager(c) manager, err := createUserManager(c)
if err != nil { if err != nil {
@ -206,7 +213,8 @@ func execTierAdd(c *cli.Context) error {
AttachmentTotalSizeLimit: attachmentTotalSizeLimit, AttachmentTotalSizeLimit: attachmentTotalSizeLimit,
AttachmentExpiryDuration: c.Duration("attachment-expiry-duration"), AttachmentExpiryDuration: c.Duration("attachment-expiry-duration"),
AttachmentBandwidthLimit: attachmentBandwidthLimit, AttachmentBandwidthLimit: attachmentBandwidthLimit,
StripePriceID: c.String("stripe-price-id"), StripeMonthlyPriceID: c.String("stripe-monthly-price-id"),
StripeYearlyPriceID: c.String("stripe-yearly-price-id"),
} }
if err := manager.AddTier(tier); err != nil { if err := manager.AddTier(tier); err != nil {
return err return err
@ -273,8 +281,16 @@ func execTierChange(c *cli.Context) error {
return err return err
} }
} }
if c.IsSet("stripe-price-id") { if c.IsSet("stripe-monthly-price-id") {
tier.StripePriceID = c.String("stripe-price-id") tier.StripeMonthlyPriceID = c.String("stripe-monthly-price-id")
}
if c.IsSet("stripe-yearly-price-id") {
tier.StripeYearlyPriceID = c.String("stripe-yearly-price-id")
}
if tier.StripeMonthlyPriceID != "" && tier.StripeYearlyPriceID == "" {
return errors.New("if stripe-monthly-price-id is set, stripe-yearly-price-id must also be set")
} else if tier.StripeMonthlyPriceID == "" && tier.StripeYearlyPriceID != "" {
return errors.New("if stripe-yearly-price-id is set, stripe-monthly-price-id must also be set")
} }
if err := manager.UpdateTier(tier); err != nil { if err := manager.UpdateTier(tier); err != nil {
return err return err
@ -319,9 +335,9 @@ func execTierList(c *cli.Context) error {
} }
func printTier(c *cli.Context, tier *user.Tier) { func printTier(c *cli.Context, tier *user.Tier) {
stripePriceID := tier.StripePriceID prices := "(none)"
if stripePriceID == "" { if tier.StripeMonthlyPriceID != "" && tier.StripeYearlyPriceID != "" {
stripePriceID = "(none)" prices = fmt.Sprintf("%s / %s", tier.StripeMonthlyPriceID, tier.StripeYearlyPriceID)
} }
fmt.Fprintf(c.App.ErrWriter, "tier %s (id: %s)\n", tier.Code, tier.ID) fmt.Fprintf(c.App.ErrWriter, "tier %s (id: %s)\n", tier.Code, tier.ID)
fmt.Fprintf(c.App.ErrWriter, "- Name: %s\n", tier.Name) fmt.Fprintf(c.App.ErrWriter, "- Name: %s\n", tier.Name)
@ -333,5 +349,5 @@ func printTier(c *cli.Context, tier *user.Tier) {
fmt.Fprintf(c.App.ErrWriter, "- Attachment total size limit: %s\n", util.FormatSize(tier.AttachmentTotalSizeLimit)) fmt.Fprintf(c.App.ErrWriter, "- Attachment total size limit: %s\n", util.FormatSize(tier.AttachmentTotalSizeLimit))
fmt.Fprintf(c.App.ErrWriter, "- Attachment expiry duration: %s (%d seconds)\n", tier.AttachmentExpiryDuration.String(), int64(tier.AttachmentExpiryDuration.Seconds())) fmt.Fprintf(c.App.ErrWriter, "- Attachment expiry duration: %s (%d seconds)\n", tier.AttachmentExpiryDuration.String(), int64(tier.AttachmentExpiryDuration.Seconds()))
fmt.Fprintf(c.App.ErrWriter, "- Attachment daily bandwidth limit: %s\n", util.FormatSize(tier.AttachmentBandwidthLimit)) fmt.Fprintf(c.App.ErrWriter, "- Attachment daily bandwidth limit: %s\n", util.FormatSize(tier.AttachmentBandwidthLimit))
fmt.Fprintf(c.App.ErrWriter, "- Stripe price: %s\n", stripePriceID) fmt.Fprintf(c.App.ErrWriter, "- Stripe prices (monthly/yearly): %s\n", prices)
} }

View File

@ -36,7 +36,8 @@ func TestCLI_Tier_AddListChangeDelete(t *testing.T) {
"--attachment-expiry-duration=7h", "--attachment-expiry-duration=7h",
"--attachment-total-size-limit=10G", "--attachment-total-size-limit=10G",
"--attachment-bandwidth-limit=100G", "--attachment-bandwidth-limit=100G",
"--stripe-price-id=price_991", "--stripe-monthly-price-id=price_991",
"--stripe-yearly-price-id=price_992",
"pro", "pro",
)) ))
require.Contains(t, stderr.String(), "- Message limit: 999") require.Contains(t, stderr.String(), "- Message limit: 999")
@ -46,7 +47,7 @@ func TestCLI_Tier_AddListChangeDelete(t *testing.T) {
require.Contains(t, stderr.String(), "- Attachment file size limit: 100.0 MB") require.Contains(t, stderr.String(), "- Attachment file size limit: 100.0 MB")
require.Contains(t, stderr.String(), "- Attachment expiry duration: 7h") require.Contains(t, stderr.String(), "- Attachment expiry duration: 7h")
require.Contains(t, stderr.String(), "- Attachment total size limit: 10.0 GB") require.Contains(t, stderr.String(), "- Attachment total size limit: 10.0 GB")
require.Contains(t, stderr.String(), "- Stripe price: price_991") require.Contains(t, stderr.String(), "- Stripe prices (monthly/yearly): price_991 / price_992")
app, _, _, stderr = newTestApp() app, _, _, stderr = newTestApp()
require.Nil(t, runTierCommand(app, conf, "remove", "pro")) require.Nil(t, runTierCommand(app, conf, "remove", "pro"))

View File

@ -8,6 +8,7 @@ and the [ntfy Android app](https://github.com/binwiederhier/ntfy-android/release
* Support for publishing to protected topics via email with access tokens ([#612](https://github.com/binwiederhier/ntfy/pull/621), thanks to [@tamcore](https://github.com/tamcore)) * Support for publishing to protected topics via email with access tokens ([#612](https://github.com/binwiederhier/ntfy/pull/621), thanks to [@tamcore](https://github.com/tamcore))
* Support for base64-encoded and nested multipart emails ([#610](https://github.com/binwiederhier/ntfy/issues/610), thanks to [@Robert-litts](https://github.com/Robert-litts)) * Support for base64-encoded and nested multipart emails ([#610](https://github.com/binwiederhier/ntfy/issues/610), thanks to [@Robert-litts](https://github.com/Robert-litts))
* Add support for annual billing intervals (no ticket)
**Bug fixes + maintenance:** **Bug fixes + maintenance:**

View File

@ -45,11 +45,11 @@ type Server struct {
visitors map[string]*visitor // ip:<ip> or user:<user> visitors map[string]*visitor // ip:<ip> or user:<user>
firebaseClient *firebaseClient firebaseClient *firebaseClient
messages int64 messages int64
userManager *user.Manager // Might be nil! userManager *user.Manager // Might be nil!
messageCache *messageCache // Database that stores the messages messageCache *messageCache // Database that stores the messages
fileCache *fileCache // File system based cache that stores attachments fileCache *fileCache // File system based cache that stores attachments
stripe stripeAPI // Stripe API, can be replaced with a mock stripe stripeAPI // Stripe API, can be replaced with a mock
priceCache *util.LookupCache[map[string]string] // Stripe price ID -> formatted price priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!)
closeChan chan bool closeChan chan bool
mu sync.Mutex mu sync.Mutex
} }

View File

@ -100,6 +100,7 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, r *http.Request, v *vis
Customer: true, Customer: true,
Subscription: u.Billing.StripeSubscriptionID != "", Subscription: u.Billing.StripeSubscriptionID != "",
Status: string(u.Billing.StripeSubscriptionStatus), Status: string(u.Billing.StripeSubscriptionStatus),
Interval: string(u.Billing.StripeSubscriptionInterval),
PaidUntil: u.Billing.StripeSubscriptionPaidUntil.Unix(), PaidUntil: u.Billing.StripeSubscriptionPaidUntil.Unix(),
CancelAt: u.Billing.StripeSubscriptionCancelAt.Unix(), CancelAt: u.Billing.StripeSubscriptionCancelAt.Unix(),
} }

View File

@ -80,14 +80,17 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
return err return err
} }
for _, tier := range tiers { for _, tier := range tiers {
priceStr, ok := prices[tier.StripePriceID] priceMonth, priceYear := prices[tier.StripeMonthlyPriceID], prices[tier.StripeYearlyPriceID]
if tier.StripePriceID == "" || !ok { if priceMonth == 0 || priceYear == 0 { // Only allow tiers that have both prices!
continue continue
} }
response = append(response, &apiAccountBillingTier{ response = append(response, &apiAccountBillingTier{
Code: tier.Code, Code: tier.Code,
Name: tier.Name, Name: tier.Name,
Price: priceStr, Prices: &apiAccountBillingPrices{
Month: priceMonth,
Year: priceYear,
},
Limits: &apiAccountLimits{ Limits: &apiAccountLimits{
Basis: string(visitorLimitBasisTier), Basis: string(visitorLimitBasisTier),
Messages: tier.MessageLimit, Messages: tier.MessageLimit,
@ -117,11 +120,21 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
tier, err := s.userManager.Tier(req.Tier) tier, err := s.userManager.Tier(req.Tier)
if err != nil { if err != nil {
return err return err
} else if tier.StripePriceID == "" { }
var priceID string
if req.Interval == string(stripe.PriceRecurringIntervalMonth) && tier.StripeMonthlyPriceID != "" {
priceID = tier.StripeMonthlyPriceID
} else if req.Interval == string(stripe.PriceRecurringIntervalYear) && tier.StripeYearlyPriceID != "" {
priceID = tier.StripeYearlyPriceID
} else {
return errNotAPaidTier return errNotAPaidTier
} }
logvr(v, r). logvr(v, r).
With(tier). With(tier).
Fields(log.Context{
"stripe_price_id": priceID,
"stripe_subscription_interval": req.Interval,
}).
Tag(tagStripe). Tag(tagStripe).
Info("Creating Stripe checkout flow") Info("Creating Stripe checkout flow")
var stripeCustomerID *string var stripeCustomerID *string
@ -143,7 +156,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
AllowPromotionCodes: stripe.Bool(true), AllowPromotionCodes: stripe.Bool(true),
LineItems: []*stripe.CheckoutSessionLineItemParams{ LineItems: []*stripe.CheckoutSessionLineItemParams{
{ {
Price: stripe.String(tier.StripePriceID), Price: stripe.String(priceID),
Quantity: stripe.Int64(1), Quantity: stripe.Int64(1),
}, },
}, },
@ -180,10 +193,11 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
sub, err := s.stripe.GetSubscription(sess.Subscription.ID) sub, err := s.stripe.GetSubscription(sess.Subscription.ID)
if err != nil { if err != nil {
return err return err
} else if sub.Items == nil || len(sub.Items.Data) != 1 || sub.Items.Data[0].Price == nil { } else if sub.Items == nil || len(sub.Items.Data) != 1 || sub.Items.Data[0].Price == nil || sub.Items.Data[0].Price.Recurring == nil {
return wrapErrHTTP(errHTTPBadRequestBillingRequestInvalid, "more than one line item in existing subscription") return wrapErrHTTP(errHTTPBadRequestBillingRequestInvalid, "more than one line item in existing subscription")
} }
tier, err := s.userManager.TierByStripePrice(sub.Items.Data[0].Price.ID) priceID, interval := sub.Items.Data[0].Price.ID, sub.Items.Data[0].Price.Recurring.Interval
tier, err := s.userManager.TierByStripePrice(priceID)
if err != nil { if err != nil {
return err return err
} }
@ -197,8 +211,10 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
Tag(tagStripe). Tag(tagStripe).
Fields(log.Context{ Fields(log.Context{
"stripe_customer_id": sess.Customer.ID, "stripe_customer_id": sess.Customer.ID,
"stripe_price_id": priceID,
"stripe_subscription_id": sub.ID, "stripe_subscription_id": sub.ID,
"stripe_subscription_status": string(sub.Status), "stripe_subscription_status": string(sub.Status),
"stripe_subscription_interval": string(interval),
"stripe_subscription_paid_until": sub.CurrentPeriodEnd, "stripe_subscription_paid_until": sub.CurrentPeriodEnd,
}). }).
Info("Stripe checkout flow succeeded, updating user tier and subscription") Info("Stripe checkout flow succeeded, updating user tier and subscription")
@ -213,7 +229,7 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
if _, err := s.stripe.UpdateCustomer(sess.Customer.ID, customerParams); err != nil { if _, err := s.stripe.UpdateCustomer(sess.Customer.ID, customerParams); err != nil {
return err return err
} }
if err := s.updateSubscriptionAndTier(r, v, u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil { if err := s.updateSubscriptionAndTier(r, v, u, tier, sess.Customer.ID, sub.ID, string(sub.Status), string(interval), sub.CurrentPeriodEnd, sub.CancelAt); err != nil {
return err return err
} }
http.Redirect(w, r, s.config.BaseURL+accountPath, http.StatusSeeOther) http.Redirect(w, r, s.config.BaseURL+accountPath, http.StatusSeeOther)
@ -235,15 +251,24 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
if err != nil { if err != nil {
return err return err
} }
var priceID string
if req.Interval == string(stripe.PriceRecurringIntervalMonth) && tier.StripeMonthlyPriceID != "" {
priceID = tier.StripeMonthlyPriceID
} else if req.Interval == string(stripe.PriceRecurringIntervalYear) && tier.StripeYearlyPriceID != "" {
priceID = tier.StripeYearlyPriceID
} else {
return errNotAPaidTier
}
logvr(v, r). logvr(v, r).
Tag(tagStripe). Tag(tagStripe).
Fields(log.Context{ Fields(log.Context{
"new_tier_id": tier.ID, "new_tier_id": tier.ID,
"new_tier_name": tier.Name, "new_tier_code": tier.Code,
"new_tier_stripe_price_id": tier.StripePriceID, "new_tier_stripe_price_id": priceID,
"new_tier_stripe_subscription_interval": req.Interval,
// Other stripe_* fields filled by visitor context // Other stripe_* fields filled by visitor context
}). }).
Info("Changing Stripe subscription and billing tier to %s/%s (price %s)", tier.ID, tier.Name, tier.StripePriceID) Info("Changing Stripe subscription and billing tier to %s/%s (price %s, %s)", tier.ID, tier.Name, priceID, req.Interval)
sub, err := s.stripe.GetSubscription(u.Billing.StripeSubscriptionID) sub, err := s.stripe.GetSubscription(u.Billing.StripeSubscriptionID)
if err != nil { if err != nil {
return err return err
@ -252,11 +277,11 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
} }
params := &stripe.SubscriptionParams{ params := &stripe.SubscriptionParams{
CancelAtPeriodEnd: stripe.Bool(false), CancelAtPeriodEnd: stripe.Bool(false),
ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)), ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorAlwaysInvoice)),
Items: []*stripe.SubscriptionItemsParams{ Items: []*stripe.SubscriptionItemsParams{
{ {
ID: stripe.String(sub.Items.Data[0].ID), ID: stripe.String(sub.Items.Data[0].ID),
Price: stripe.String(tier.StripePriceID), Price: stripe.String(priceID),
}, },
}, },
} }
@ -345,20 +370,22 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(r *http.Request,
ev, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event.Data.Raw))) ev, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event.Data.Raw)))
if err != nil { if err != nil {
return err return err
} else if ev.ID == "" || ev.Customer == "" || ev.Status == "" || ev.CurrentPeriodEnd == 0 || ev.Items == nil || len(ev.Items.Data) != 1 || ev.Items.Data[0].Price == nil || ev.Items.Data[0].Price.ID == "" { } else if ev.ID == "" || ev.Customer == "" || ev.Status == "" || ev.CurrentPeriodEnd == 0 || ev.Items == nil || len(ev.Items.Data) != 1 || ev.Items.Data[0].Price == nil || ev.Items.Data[0].Price.ID == "" || ev.Items.Data[0].Price.Recurring == nil {
logvr(v, r).Tag(tagStripe).Field("stripe_request", fmt.Sprintf("%#v", ev)).Warn("Unexpected request from Stripe")
return errHTTPBadRequestBillingRequestInvalid return errHTTPBadRequestBillingRequestInvalid
} }
subscriptionID, priceID := ev.ID, ev.Items.Data[0].Price.ID subscriptionID, priceID, interval := ev.ID, ev.Items.Data[0].Price.ID, ev.Items.Data[0].Price.Recurring.Interval
logvr(v, r). logvr(v, r).
Tag(tagStripe). Tag(tagStripe).
Fields(log.Context{ Fields(log.Context{
"stripe_webhook_type": event.Type, "stripe_webhook_type": event.Type,
"stripe_customer_id": ev.Customer, "stripe_customer_id": ev.Customer,
"stripe_price_id": priceID,
"stripe_subscription_id": ev.ID, "stripe_subscription_id": ev.ID,
"stripe_subscription_status": ev.Status, "stripe_subscription_status": ev.Status,
"stripe_subscription_interval": interval,
"stripe_subscription_paid_until": ev.CurrentPeriodEnd, "stripe_subscription_paid_until": ev.CurrentPeriodEnd,
"stripe_subscription_cancel_at": ev.CancelAt, "stripe_subscription_cancel_at": ev.CancelAt,
"stripe_price_id": priceID,
}). }).
Info("Updating subscription to status %s, with price %s", ev.Status, priceID) Info("Updating subscription to status %s, with price %s", ev.Status, priceID)
userFn := func() (*user.User, error) { userFn := func() (*user.User, error) {
@ -376,7 +403,7 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(r *http.Request,
if err != nil { if err != nil {
return err return err
} }
if err := s.updateSubscriptionAndTier(r, v, u, tier, ev.Customer, subscriptionID, ev.Status, ev.CurrentPeriodEnd, ev.CancelAt); err != nil { if err := s.updateSubscriptionAndTier(r, v, u, tier, ev.Customer, subscriptionID, ev.Status, string(interval), ev.CurrentPeriodEnd, ev.CancelAt); err != nil {
return err return err
} }
s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u)) s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u))
@ -399,14 +426,14 @@ func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(r *http.Request,
Tag(tagStripe). Tag(tagStripe).
Field("stripe_webhook_type", event.Type). Field("stripe_webhook_type", event.Type).
Info("Subscription deleted, downgrading to unpaid tier") Info("Subscription deleted, downgrading to unpaid tier")
if err := s.updateSubscriptionAndTier(r, v, u, nil, ev.Customer, "", "", 0, 0); err != nil { if err := s.updateSubscriptionAndTier(r, v, u, nil, ev.Customer, "", "", "", 0, 0); err != nil {
return err return err
} }
s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u)) s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u))
return nil return nil
} }
func (s *Server) updateSubscriptionAndTier(r *http.Request, v *visitor, u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error { func (s *Server) updateSubscriptionAndTier(r *http.Request, v *visitor, u *user.User, tier *user.Tier, customerID, subscriptionID, status, interval string, paidUntil, cancelAt int64) error {
reservationsLimit := visitorDefaultReservationsLimit reservationsLimit := visitorDefaultReservationsLimit
if tier != nil { if tier != nil {
reservationsLimit = tier.ReservationLimit reservationsLimit = tier.ReservationLimit
@ -423,9 +450,8 @@ func (s *Server) updateSubscriptionAndTier(r *http.Request, v *visitor, u *user.
logvr(v, r). logvr(v, r).
Tag(tagStripe). Tag(tagStripe).
Fields(log.Context{ Fields(log.Context{
"new_tier_id": tier.ID, "new_tier_id": tier.ID,
"new_tier_name": tier.Name, "new_tier_code": tier.Code,
"new_tier_stripe_price_id": tier.StripePriceID,
}). }).
Info("Changing tier to tier %s (%s) for user %s", tier.ID, tier.Name, u.Name) Info("Changing tier to tier %s (%s) for user %s", tier.ID, tier.Name, u.Name)
if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil { if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil {
@ -437,6 +463,7 @@ func (s *Server) updateSubscriptionAndTier(r *http.Request, v *visitor, u *user.
StripeCustomerID: customerID, StripeCustomerID: customerID,
StripeSubscriptionID: subscriptionID, StripeSubscriptionID: subscriptionID,
StripeSubscriptionStatus: stripe.SubscriptionStatus(status), StripeSubscriptionStatus: stripe.SubscriptionStatus(status),
StripeSubscriptionInterval: stripe.PriceRecurringInterval(interval),
StripeSubscriptionPaidUntil: time.Unix(paidUntil, 0), StripeSubscriptionPaidUntil: time.Unix(paidUntil, 0),
StripeSubscriptionCancelAt: time.Unix(cancelAt, 0), StripeSubscriptionCancelAt: time.Unix(cancelAt, 0),
} }
@ -448,20 +475,16 @@ func (s *Server) updateSubscriptionAndTier(r *http.Request, v *visitor, u *user.
// fetchStripePrices contacts the Stripe API to retrieve all prices. This is used by the server to cache the prices // fetchStripePrices contacts the Stripe API to retrieve all prices. This is used by the server to cache the prices
// in memory, and ultimately for the web app to display the price table. // in memory, and ultimately for the web app to display the price table.
func (s *Server) fetchStripePrices() (map[string]string, error) { func (s *Server) fetchStripePrices() (map[string]int64, error) {
log.Debug("Caching prices from Stripe API") log.Debug("Caching prices from Stripe API")
priceMap := make(map[string]string) priceMap := make(map[string]int64)
prices, err := s.stripe.ListPrices(&stripe.PriceListParams{Active: stripe.Bool(true)}) prices, err := s.stripe.ListPrices(&stripe.PriceListParams{Active: stripe.Bool(true)})
if err != nil { if err != nil {
log.Warn("Fetching Stripe prices failed: %s", err.Error()) log.Warn("Fetching Stripe prices failed: %s", err.Error())
return nil, err return nil, err
} }
for _, p := range prices { for _, p := range prices {
if p.UnitAmount%100 == 0 { priceMap[p.ID] = p.UnitAmount
priceMap[p.ID] = fmt.Sprintf("$%d", p.UnitAmount/100)
} else {
priceMap[p.ID] = fmt.Sprintf("$%.2f", float64(p.UnitAmount)/100)
}
log.Trace("- Caching price %s = %v", p.ID, priceMap[p.ID]) log.Trace("- Caching price %s = %v", p.ID, priceMap[p.ID])
} }
return priceMap, nil return priceMap, nil

View File

@ -37,7 +37,9 @@ func TestPayments_Tiers(t *testing.T) {
On("ListPrices", mock.Anything). On("ListPrices", mock.Anything).
Return([]*stripe.Price{ Return([]*stripe.Price{
{ID: "price_123", UnitAmount: 500}, {ID: "price_123", UnitAmount: 500},
{ID: "price_124", UnitAmount: 5000},
{ID: "price_456", UnitAmount: 1000}, {ID: "price_456", UnitAmount: 1000},
{ID: "price_457", UnitAmount: 10000},
{ID: "price_999", UnitAmount: 9999}, {ID: "price_999", UnitAmount: 9999},
}, nil) }, nil)
@ -58,7 +60,8 @@ func TestPayments_Tiers(t *testing.T) {
AttachmentFileSizeLimit: 999, AttachmentFileSizeLimit: 999,
AttachmentTotalSizeLimit: 888, AttachmentTotalSizeLimit: 888,
AttachmentExpiryDuration: time.Minute, AttachmentExpiryDuration: time.Minute,
StripePriceID: "price_123", StripeMonthlyPriceID: "price_123",
StripeYearlyPriceID: "price_124",
})) }))
require.Nil(t, s.userManager.AddTier(&user.Tier{ require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_444", ID: "ti_444",
@ -71,7 +74,8 @@ func TestPayments_Tiers(t *testing.T) {
AttachmentFileSizeLimit: 999111, AttachmentFileSizeLimit: 999111,
AttachmentTotalSizeLimit: 888111, AttachmentTotalSizeLimit: 888111,
AttachmentExpiryDuration: time.Hour, AttachmentExpiryDuration: time.Hour,
StripePriceID: "price_456", StripeMonthlyPriceID: "price_456",
StripeYearlyPriceID: "price_457",
})) }))
response := request(t, s, "GET", "/v1/tiers", "", nil) response := request(t, s, "GET", "/v1/tiers", "", nil)
require.Equal(t, 200, response.Code) require.Equal(t, 200, response.Code)
@ -98,6 +102,8 @@ func TestPayments_Tiers(t *testing.T) {
require.Equal(t, "pro", tier.Code) require.Equal(t, "pro", tier.Code)
require.Equal(t, "Pro", tier.Name) require.Equal(t, "Pro", tier.Name)
require.Equal(t, "tier", tier.Limits.Basis) require.Equal(t, "tier", tier.Limits.Basis)
require.Equal(t, int64(500), tier.Prices.Month)
require.Equal(t, int64(5000), tier.Prices.Year)
require.Equal(t, int64(777), tier.Limits.Reservations) require.Equal(t, int64(777), tier.Limits.Reservations)
require.Equal(t, int64(1000), tier.Limits.Messages) require.Equal(t, int64(1000), tier.Limits.Messages)
require.Equal(t, int64(3600), tier.Limits.MessagesExpiryDuration) require.Equal(t, int64(3600), tier.Limits.MessagesExpiryDuration)
@ -109,6 +115,8 @@ func TestPayments_Tiers(t *testing.T) {
tier = tiers[2] tier = tiers[2]
require.Equal(t, "business", tier.Code) require.Equal(t, "business", tier.Code)
require.Equal(t, "Business", tier.Name) require.Equal(t, "Business", tier.Name)
require.Equal(t, int64(1000), tier.Prices.Month)
require.Equal(t, int64(10000), tier.Prices.Year)
require.Equal(t, "tier", tier.Limits.Basis) require.Equal(t, "tier", tier.Limits.Basis)
require.Equal(t, int64(777333), tier.Limits.Reservations) require.Equal(t, int64(777333), tier.Limits.Reservations)
require.Equal(t, int64(2000), tier.Limits.Messages) require.Equal(t, int64(2000), tier.Limits.Messages)
@ -136,14 +144,14 @@ func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
// Create tier and user // Create tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{ require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_123", ID: "ti_123",
Code: "pro", Code: "pro",
StripePriceID: "price_123", StripeMonthlyPriceID: "price_123",
})) }))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
// Create subscription // Create subscription
response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{ response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro", "interval": "month"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"), "Authorization": util.BasicAuth("phil", "phil"),
}) })
require.Equal(t, 200, response.Code) require.Equal(t, 200, response.Code)
@ -172,9 +180,9 @@ func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
// Create tier and user // Create tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{ require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_123", ID: "ti_123",
Code: "pro", Code: "pro",
StripePriceID: "price_123", StripeMonthlyPriceID: "price_123",
})) }))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
@ -187,7 +195,7 @@ func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
require.Nil(t, s.userManager.ChangeBilling(u.Name, billing)) require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
// Create subscription // Create subscription
response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{ response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro", "interval": "month"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"), "Authorization": util.BasicAuth("phil", "phil"),
}) })
require.Equal(t, 200, response.Code) require.Equal(t, 200, response.Code)
@ -214,9 +222,9 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
// Create tier and user // Create tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{ require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_123", ID: "ti_123",
Code: "pro", Code: "pro",
StripePriceID: "price_123", StripeMonthlyPriceID: "price_123",
})) }))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
@ -267,7 +275,7 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes
require.Nil(t, s.userManager.AddTier(&user.Tier{ require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_123", ID: "ti_123",
Code: "starter", Code: "starter",
StripePriceID: "price_1234", StripeMonthlyPriceID: "price_1234",
ReservationLimit: 1, ReservationLimit: 1,
MessageLimit: 220, // 220 * 5% = 11 requests before rate limiting kicks in MessageLimit: 220, // 220 * 5% = 11 requests before rate limiting kicks in
MessageExpiryDuration: time.Hour, MessageExpiryDuration: time.Hour,
@ -298,7 +306,12 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes
Items: &stripe.SubscriptionItemList{ Items: &stripe.SubscriptionItemList{
Data: []*stripe.SubscriptionItem{ Data: []*stripe.SubscriptionItem{
{ {
Price: &stripe.Price{ID: "price_1234"}, Price: &stripe.Price{
ID: "price_1234",
Recurring: &stripe.PriceRecurring{
Interval: stripe.PriceRecurringIntervalMonth,
},
},
}, },
}, },
}, },
@ -333,6 +346,7 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes
require.Equal(t, "", u.Billing.StripeCustomerID) require.Equal(t, "", u.Billing.StripeCustomerID)
require.Equal(t, "", u.Billing.StripeSubscriptionID) require.Equal(t, "", u.Billing.StripeSubscriptionID)
require.Equal(t, stripe.SubscriptionStatus(""), u.Billing.StripeSubscriptionStatus) require.Equal(t, stripe.SubscriptionStatus(""), u.Billing.StripeSubscriptionStatus)
require.Equal(t, stripe.PriceRecurringInterval(""), u.Billing.StripeSubscriptionInterval)
require.Equal(t, int64(0), u.Billing.StripeSubscriptionPaidUntil.Unix()) require.Equal(t, int64(0), u.Billing.StripeSubscriptionPaidUntil.Unix())
require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix()) require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
require.Equal(t, int64(0), u.Stats.Messages) // Messages and emails are not persisted for no-tier users! require.Equal(t, int64(0), u.Stats.Messages) // Messages and emails are not persisted for no-tier users!
@ -349,6 +363,7 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes
require.Equal(t, "acct_5555", u.Billing.StripeCustomerID) require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID) require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID)
require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus) require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus)
require.Equal(t, stripe.PriceRecurringIntervalMonth, u.Billing.StripeSubscriptionInterval)
require.Equal(t, int64(123456789), u.Billing.StripeSubscriptionPaidUntil.Unix()) require.Equal(t, int64(123456789), u.Billing.StripeSubscriptionPaidUntil.Unix())
require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix()) require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
require.Equal(t, int64(0), u.Stats.Messages) require.Equal(t, int64(0), u.Stats.Messages)
@ -423,7 +438,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
require.Nil(t, s.userManager.AddTier(&user.Tier{ require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_1", ID: "ti_1",
Code: "starter", Code: "starter",
StripePriceID: "price_1234", // ! StripeMonthlyPriceID: "price_1234", // !
ReservationLimit: 1, // ! ReservationLimit: 1, // !
MessageLimit: 100, MessageLimit: 100,
MessageExpiryDuration: time.Hour, MessageExpiryDuration: time.Hour,
@ -435,7 +450,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
require.Nil(t, s.userManager.AddTier(&user.Tier{ require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_2", ID: "ti_2",
Code: "pro", Code: "pro",
StripePriceID: "price_1111", // ! StripeMonthlyPriceID: "price_1111", // !
ReservationLimit: 3, // ! ReservationLimit: 3, // !
MessageLimit: 200, MessageLimit: 200,
MessageExpiryDuration: time.Hour, MessageExpiryDuration: time.Hour,
@ -457,6 +472,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
StripeCustomerID: "acct_5555", StripeCustomerID: "acct_5555",
StripeSubscriptionID: "sub_1234", StripeSubscriptionID: "sub_1234",
StripeSubscriptionStatus: stripe.SubscriptionStatusPastDue, StripeSubscriptionStatus: stripe.SubscriptionStatusPastDue,
StripeSubscriptionInterval: stripe.PriceRecurringIntervalMonth,
StripeSubscriptionPaidUntil: time.Unix(123, 0), StripeSubscriptionPaidUntil: time.Unix(123, 0),
StripeSubscriptionCancelAt: time.Unix(456, 0), StripeSubscriptionCancelAt: time.Unix(456, 0),
} }
@ -499,9 +515,10 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
require.Equal(t, "starter", u.Tier.Code) // Not "pro" require.Equal(t, "starter", u.Tier.Code) // Not "pro"
require.Equal(t, "acct_5555", u.Billing.StripeCustomerID) require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID) require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID)
require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus) // Not "past_due" require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus) // Not "past_due"
require.Equal(t, int64(1674268231), u.Billing.StripeSubscriptionPaidUntil.Unix()) // Updated require.Equal(t, stripe.PriceRecurringIntervalYear, u.Billing.StripeSubscriptionInterval) // Not "month"
require.Equal(t, int64(1674299999), u.Billing.StripeSubscriptionCancelAt.Unix()) // Updated require.Equal(t, int64(1674268231), u.Billing.StripeSubscriptionPaidUntil.Unix()) // Updated
require.Equal(t, int64(1674299999), u.Billing.StripeSubscriptionCancelAt.Unix()) // Updated
// Verify that reservations were deleted // Verify that reservations were deleted
r, err := s.userManager.Reservations("phil") r, err := s.userManager.Reservations("phil")
@ -546,10 +563,10 @@ func TestPayments_Webhook_Subscription_Deleted(t *testing.T) {
// Create a user with a Stripe subscription and 3 reservations // Create a user with a Stripe subscription and 3 reservations
require.Nil(t, s.userManager.AddTier(&user.Tier{ require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_1", ID: "ti_1",
Code: "pro", Code: "pro",
StripePriceID: "price_1234", StripeMonthlyPriceID: "price_1234",
ReservationLimit: 1, ReservationLimit: 1,
})) }))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("phil", "pro")) require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
@ -562,6 +579,7 @@ func TestPayments_Webhook_Subscription_Deleted(t *testing.T) {
StripeCustomerID: "acct_5555", StripeCustomerID: "acct_5555",
StripeSubscriptionID: "sub_1234", StripeSubscriptionID: "sub_1234",
StripeSubscriptionStatus: stripe.SubscriptionStatusPastDue, StripeSubscriptionStatus: stripe.SubscriptionStatusPastDue,
StripeSubscriptionInterval: stripe.PriceRecurringIntervalMonth,
StripeSubscriptionPaidUntil: time.Unix(123, 0), StripeSubscriptionPaidUntil: time.Unix(123, 0),
StripeSubscriptionCancelAt: time.Unix(0, 0), StripeSubscriptionCancelAt: time.Unix(0, 0),
})) }))
@ -615,11 +633,11 @@ func TestPayments_Subscription_Update_Different_Tier(t *testing.T) {
stripeMock. stripeMock.
On("UpdateSubscription", "sub_123", &stripe.SubscriptionParams{ On("UpdateSubscription", "sub_123", &stripe.SubscriptionParams{
CancelAtPeriodEnd: stripe.Bool(false), CancelAtPeriodEnd: stripe.Bool(false),
ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)), ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorAlwaysInvoice)),
Items: []*stripe.SubscriptionItemsParams{ Items: []*stripe.SubscriptionItemsParams{
{ {
ID: stripe.String("someid_123"), ID: stripe.String("someid_123"),
Price: stripe.String("price_456"), Price: stripe.String("price_457"),
}, },
}, },
}). }).
@ -627,14 +645,16 @@ func TestPayments_Subscription_Update_Different_Tier(t *testing.T) {
// Create tier and user // Create tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{ require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_123", ID: "ti_123",
Code: "pro", Code: "pro",
StripePriceID: "price_123", StripeMonthlyPriceID: "price_123",
StripeYearlyPriceID: "price_124",
})) }))
require.Nil(t, s.userManager.AddTier(&user.Tier{ require.Nil(t, s.userManager.AddTier(&user.Tier{
ID: "ti_456", ID: "ti_456",
Code: "business", Code: "business",
StripePriceID: "price_456", StripeMonthlyPriceID: "price_456",
StripeYearlyPriceID: "price_457",
})) }))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("phil", "pro")) require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
@ -644,7 +664,7 @@ func TestPayments_Subscription_Update_Different_Tier(t *testing.T) {
})) }))
// Call endpoint to change subscription // Call endpoint to change subscription
rr := request(t, s, "PUT", "/v1/account/billing/subscription", `{"tier":"business"}`, map[string]string{ rr := request(t, s, "PUT", "/v1/account/billing/subscription", `{"tier":"business","interval":"year"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"), "Authorization": util.BasicAuth("phil", "phil"),
}) })
require.Equal(t, 200, rr.Code) require.Equal(t, 200, rr.Code)
@ -795,7 +815,10 @@ const subscriptionUpdatedEventJSON = `
"data": [ "data": [
{ {
"price": { "price": {
"id": "price_1234" "id": "price_1234",
"recurring": {
"interval": "year"
}
} }
} }
] ]
@ -818,7 +841,10 @@ const subscriptionDeletedEventJSON = `
"data": [ "data": [
{ {
"price": { "price": {
"id": "price_1234" "id": "price_1234",
"recurring": {
"interval": "month"
}
} }
} }
] ]

View File

@ -309,6 +309,7 @@ type apiAccountBilling struct {
Customer bool `json:"customer"` Customer bool `json:"customer"`
Subscription bool `json:"subscription"` Subscription bool `json:"subscription"`
Status string `json:"status,omitempty"` Status string `json:"status,omitempty"`
Interval string `json:"interval,omitempty"`
PaidUntil int64 `json:"paid_until,omitempty"` PaidUntil int64 `json:"paid_until,omitempty"`
CancelAt int64 `json:"cancel_at,omitempty"` CancelAt int64 `json:"cancel_at,omitempty"`
} }
@ -343,11 +344,16 @@ type apiConfigResponse struct {
DisallowedTopics []string `json:"disallowed_topics"` DisallowedTopics []string `json:"disallowed_topics"`
} }
type apiAccountBillingPrices struct {
Month int64 `json:"month"`
Year int64 `json:"year"`
}
type apiAccountBillingTier struct { type apiAccountBillingTier struct {
Code string `json:"code,omitempty"` Code string `json:"code,omitempty"`
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
Price string `json:"price,omitempty"` Prices *apiAccountBillingPrices `json:"prices,omitempty"`
Limits *apiAccountLimits `json:"limits"` Limits *apiAccountLimits `json:"limits"`
} }
type apiAccountBillingSubscriptionCreateResponse struct { type apiAccountBillingSubscriptionCreateResponse struct {
@ -355,7 +361,8 @@ type apiAccountBillingSubscriptionCreateResponse struct {
} }
type apiAccountBillingSubscriptionChangeRequest struct { type apiAccountBillingSubscriptionChangeRequest struct {
Tier string `json:"tier"` Tier string `json:"tier"`
Interval string `json:"interval"`
} }
type apiAccountBillingPortalRedirectResponse struct { type apiAccountBillingPortalRedirectResponse struct {
@ -385,7 +392,10 @@ type apiStripeSubscriptionUpdatedEvent struct {
Items *struct { Items *struct {
Data []*struct { Data []*struct {
Price *struct { Price *struct {
ID string `json:"id"` ID string `json:"id"`
Recurring *struct {
Interval string `json:"interval"`
} `json:"recurring"`
} `json:"price"` } `json:"price"`
} `json:"data"` } `json:"data"`
} `json:"items"` } `json:"items"`

View File

@ -46,7 +46,8 @@ var (
// Manager-related queries // Manager-related queries
const ( const (
createTablesQueriesNoTx = ` createTablesQueries = `
BEGIN;
CREATE TABLE IF NOT EXISTS tier ( CREATE TABLE IF NOT EXISTS tier (
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY,
code TEXT NOT NULL, code TEXT NOT NULL,
@ -59,10 +60,12 @@ const (
attachment_total_size_limit INT NOT NULL, attachment_total_size_limit INT NOT NULL,
attachment_expiry_duration INT NOT NULL, attachment_expiry_duration INT NOT NULL,
attachment_bandwidth_limit INT NOT NULL, attachment_bandwidth_limit INT NOT NULL,
stripe_price_id TEXT stripe_monthly_price_id TEXT,
stripe_yearly_price_id TEXT
); );
CREATE UNIQUE INDEX idx_tier_code ON tier (code); CREATE UNIQUE INDEX idx_tier_code ON tier (code);
CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id); CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
CREATE TABLE IF NOT EXISTS user ( CREATE TABLE IF NOT EXISTS user (
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY,
tier_id TEXT, tier_id TEXT,
@ -76,6 +79,7 @@ const (
stripe_customer_id TEXT, stripe_customer_id TEXT,
stripe_subscription_id TEXT, stripe_subscription_id TEXT,
stripe_subscription_status TEXT, stripe_subscription_status TEXT,
stripe_subscription_interval TEXT,
stripe_subscription_paid_until INT, stripe_subscription_paid_until INT,
stripe_subscription_cancel_at INT, stripe_subscription_cancel_at INT,
created INT NOT NULL, created INT NOT NULL,
@ -112,33 +116,33 @@ const (
INSERT INTO user (id, user, pass, role, sync_topic, created) INSERT INTO user (id, user, pass, role, sync_topic, created)
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', UNIXEPOCH()) VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', UNIXEPOCH())
ON CONFLICT (id) DO NOTHING; ON CONFLICT (id) DO NOTHING;
COMMIT;
` `
createTablesQueries = `BEGIN; ` + createTablesQueriesNoTx + ` COMMIT;`
builtinStartupQueries = ` builtinStartupQueries = `
PRAGMA foreign_keys = ON; PRAGMA foreign_keys = ON;
` `
selectUserByIDQuery = ` selectUserByIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u FROM user u
LEFT JOIN tier t on t.id = u.tier_id LEFT JOIN tier t on t.id = u.tier_id
WHERE u.id = ? WHERE u.id = ?
` `
selectUserByNameQuery = ` selectUserByNameQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u FROM user u
LEFT JOIN tier t on t.id = u.tier_id LEFT JOIN tier t on t.id = u.tier_id
WHERE user = ? WHERE user = ?
` `
selectUserByTokenQuery = ` selectUserByTokenQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u FROM user u
JOIN user_token tk on u.id = tk.user_id JOIN user_token tk on u.id = tk.user_id
LEFT JOIN tier t on t.id = u.tier_id LEFT JOIN tier t on t.id = u.tier_id
WHERE tk.token = ? AND (tk.expires = 0 OR tk.expires >= ?) WHERE tk.token = ? AND (tk.expires = 0 OR tk.expires >= ?)
` `
selectUserByStripeCustomerIDQuery = ` selectUserByStripeCustomerIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u FROM user u
LEFT JOIN tier t on t.id = u.tier_id LEFT JOIN tier t on t.id = u.tier_id
WHERE u.stripe_customer_id = ? WHERE u.stripe_customer_id = ?
@ -246,27 +250,27 @@ const (
` `
insertTierQuery = ` insertTierQuery = `
INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id) INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
` `
updateTierQuery = ` updateTierQuery = `
UPDATE tier UPDATE tier
SET name = ?, messages_limit = ?, messages_expiry_duration = ?, emails_limit = ?, reservations_limit = ?, attachment_file_size_limit = ?, attachment_total_size_limit = ?, attachment_expiry_duration = ?, attachment_bandwidth_limit = ?, stripe_price_id = ? SET name = ?, messages_limit = ?, messages_expiry_duration = ?, emails_limit = ?, reservations_limit = ?, attachment_file_size_limit = ?, attachment_total_size_limit = ?, attachment_expiry_duration = ?, attachment_bandwidth_limit = ?, stripe_monthly_price_id = ?, stripe_yearly_price_id = ?
WHERE code = ? WHERE code = ?
` `
selectTiersQuery = ` selectTiersQuery = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
FROM tier FROM tier
` `
selectTierByCodeQuery = ` selectTierByCodeQuery = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
FROM tier FROM tier
WHERE code = ? WHERE code = ?
` `
selectTierByPriceIDQuery = ` selectTierByPriceIDQuery = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
FROM tier FROM tier
WHERE stripe_price_id = ? WHERE (stripe_monthly_price_id = ? OR stripe_yearly_price_id = ?)
` `
updateUserTierQuery = `UPDATE user SET tier_id = (SELECT id FROM tier WHERE code = ?) WHERE user = ?` updateUserTierQuery = `UPDATE user SET tier_id = (SELECT id FROM tier WHERE code = ?) WHERE user = ?`
deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?` deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
@ -274,21 +278,86 @@ const (
updateBillingQuery = ` updateBillingQuery = `
UPDATE user UPDATE user
SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_paid_until = ?, stripe_subscription_cancel_at = ? SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_interval = ?, stripe_subscription_paid_until = ?, stripe_subscription_cancel_at = ?
WHERE user = ? WHERE user = ?
` `
) )
// Schema management queries // Schema management queries
const ( const (
currentSchemaVersion = 2 currentSchemaVersion = 3
insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)` insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1` updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1` selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
// 1 -> 2 (complex migration!) // 1 -> 2 (complex migration!)
migrate1To2RenameUserTableQueryNoTx = ` migrate1To2CreateTablesQueries = `
ALTER TABLE user RENAME TO user_old; ALTER TABLE user RENAME TO user_old;
CREATE TABLE IF NOT EXISTS tier (
id TEXT PRIMARY KEY,
code TEXT NOT NULL,
name TEXT NOT NULL,
messages_limit INT NOT NULL,
messages_expiry_duration INT NOT NULL,
emails_limit INT NOT NULL,
reservations_limit INT NOT NULL,
attachment_file_size_limit INT NOT NULL,
attachment_total_size_limit INT NOT NULL,
attachment_expiry_duration INT NOT NULL,
attachment_bandwidth_limit INT NOT NULL,
stripe_price_id TEXT
);
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
CREATE TABLE IF NOT EXISTS user (
id TEXT PRIMARY KEY,
tier_id TEXT,
user TEXT NOT NULL,
pass TEXT NOT NULL,
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
prefs JSON NOT NULL DEFAULT '{}',
sync_topic TEXT NOT NULL,
stats_messages INT NOT NULL DEFAULT (0),
stats_emails INT NOT NULL DEFAULT (0),
stripe_customer_id TEXT,
stripe_subscription_id TEXT,
stripe_subscription_status TEXT,
stripe_subscription_paid_until INT,
stripe_subscription_cancel_at INT,
created INT NOT NULL,
deleted INT,
FOREIGN KEY (tier_id) REFERENCES tier (id)
);
CREATE UNIQUE INDEX idx_user ON user (user);
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
CREATE TABLE IF NOT EXISTS user_access (
user_id TEXT NOT NULL,
topic TEXT NOT NULL,
read INT NOT NULL,
write INT NOT NULL,
owner_user_id INT,
PRIMARY KEY (user_id, topic),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS user_token (
user_id TEXT NOT NULL,
token TEXT NOT NULL,
label TEXT NOT NULL,
last_access INT NOT NULL,
last_origin TEXT NOT NULL,
expires INT NOT NULL,
PRIMARY KEY (user_id, token),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO user (id, user, pass, role, sync_topic, created)
VALUES ('u_everyone', '*', '', 'anonymous', '', UNIXEPOCH())
ON CONFLICT (id) DO NOTHING;
` `
migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old` migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old`
migrate1To2InsertUserNoTx = ` migrate1To2InsertUserNoTx = `
@ -304,11 +373,22 @@ const (
DROP TABLE access; DROP TABLE access;
DROP TABLE user_old; DROP TABLE user_old;
` `
// 2 -> 3
migrate2To3UpdateQueries = `
ALTER TABLE user ADD COLUMN stripe_subscription_interval TEXT;
ALTER TABLE tier RENAME COLUMN stripe_price_id TO stripe_monthly_price_id;
ALTER TABLE tier ADD COLUMN stripe_yearly_price_id TEXT;
DROP INDEX IF EXISTS idx_tier_price_id;
CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
`
) )
var ( var (
migrations = map[int]func(db *sql.DB) error{ migrations = map[int]func(db *sql.DB) error{
1: migrateFrom1, 1: migrateFrom1,
2: migrateFrom2,
} }
) )
@ -805,13 +885,13 @@ func (a *Manager) userByToken(token string) (*User, error) {
func (a *Manager) readUser(rows *sql.Rows) (*User, error) { func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
defer rows.Close() defer rows.Close()
var id, username, hash, role, prefs, syncTopic string var id, username, hash, role, prefs, syncTopic string
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierID, tierCode, tierName sql.NullString var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripeSubscriptionInterval, stripeMonthlyPriceID, stripeYearlyPriceID, tierID, tierCode, tierName sql.NullString
var messages, emails int64 var messages, emails int64
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64 var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
if !rows.Next() { if !rows.Next() {
return nil, ErrUserNotFound return nil, ErrUserNotFound
} }
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil { if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
return nil, err return nil, err
} else if err := rows.Err(); err != nil { } else if err := rows.Err(); err != nil {
return nil, err return nil, err
@ -828,11 +908,12 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
Emails: emails, Emails: emails,
}, },
Billing: &Billing{ Billing: &Billing{
StripeCustomerID: stripeCustomerID.String, // May be empty StripeCustomerID: stripeCustomerID.String, // May be empty
StripeSubscriptionID: stripeSubscriptionID.String, // May be empty StripeSubscriptionID: stripeSubscriptionID.String, // May be empty
StripeSubscriptionStatus: stripe.SubscriptionStatus(stripeSubscriptionStatus.String), // May be empty StripeSubscriptionStatus: stripe.SubscriptionStatus(stripeSubscriptionStatus.String), // May be empty
StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), // May be zero StripeSubscriptionInterval: stripe.PriceRecurringInterval(stripeSubscriptionInterval.String), // May be empty
StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0), // May be zero StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), // May be zero
StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0), // May be zero
}, },
Deleted: deleted.Valid, Deleted: deleted.Valid,
} }
@ -853,7 +934,8 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64, AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second, AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64, AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
StripePriceID: stripePriceID.String, // May be empty StripeMonthlyPriceID: stripeMonthlyPriceID.String, // May be empty
StripeYearlyPriceID: stripeYearlyPriceID.String, // May be empty
} }
} }
return user, nil return user, nil
@ -1134,7 +1216,7 @@ func (a *Manager) AddTier(tier *Tier) error {
if tier.ID == "" { if tier.ID == "" {
tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength) tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength)
} }
if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripePriceID)); err != nil { if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID)); err != nil {
return err return err
} }
return nil return nil
@ -1142,7 +1224,7 @@ func (a *Manager) AddTier(tier *Tier) error {
// UpdateTier updates a tier's properties in the database // UpdateTier updates a tier's properties in the database
func (a *Manager) UpdateTier(tier *Tier) error { func (a *Manager) UpdateTier(tier *Tier) error {
if _, err := a.db.Exec(updateTierQuery, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripePriceID), tier.Code); err != nil { if _, err := a.db.Exec(updateTierQuery, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID), tier.Code); err != nil {
return err return err
} }
return nil return nil
@ -1162,7 +1244,7 @@ func (a *Manager) RemoveTier(code string) error {
// ChangeBilling updates a user's billing fields, namely the Stripe customer ID, and subscription information // ChangeBilling updates a user's billing fields, namely the Stripe customer ID, and subscription information
func (a *Manager) ChangeBilling(username string, billing *Billing) error { func (a *Manager) ChangeBilling(username string, billing *Billing) error {
if _, err := a.db.Exec(updateBillingQuery, nullString(billing.StripeCustomerID), nullString(billing.StripeSubscriptionID), nullString(string(billing.StripeSubscriptionStatus)), nullInt64(billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(billing.StripeSubscriptionCancelAt.Unix()), username); err != nil { if _, err := a.db.Exec(updateBillingQuery, nullString(billing.StripeCustomerID), nullString(billing.StripeSubscriptionID), nullString(string(billing.StripeSubscriptionStatus)), nullString(string(billing.StripeSubscriptionInterval)), nullInt64(billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(billing.StripeSubscriptionCancelAt.Unix()), username); err != nil {
return err return err
} }
return nil return nil
@ -1200,7 +1282,7 @@ func (a *Manager) Tier(code string) (*Tier, error) {
// TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist // TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist
func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) { func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
rows, err := a.db.Query(selectTierByPriceIDQuery, priceID) rows, err := a.db.Query(selectTierByPriceIDQuery, priceID, priceID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1210,12 +1292,12 @@ func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) { func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
var id, code, name string var id, code, name string
var stripePriceID sql.NullString var stripeMonthlyPriceID, stripeYearlyPriceID sql.NullString
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit sql.NullInt64 var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit sql.NullInt64
if !rows.Next() { if !rows.Next() {
return nil, ErrTierNotFound return nil, ErrTierNotFound
} }
if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil { if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
return nil, err return nil, err
} else if err := rows.Err(); err != nil { } else if err := rows.Err(); err != nil {
return nil, err return nil, err
@ -1233,7 +1315,8 @@ func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64, AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second, AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64, AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
StripePriceID: stripePriceID.String, // May be empty StripeMonthlyPriceID: stripeMonthlyPriceID.String, // May be empty
StripeYearlyPriceID: stripeYearlyPriceID.String, // May be empty
}, nil }, nil
} }
@ -1313,10 +1396,7 @@ func migrateFrom1(db *sql.DB) error {
} }
defer tx.Rollback() defer tx.Rollback()
// Rename user -> user_old, and create new tables // Rename user -> user_old, and create new tables
if _, err := tx.Exec(migrate1To2RenameUserTableQueryNoTx); err != nil { if _, err := tx.Exec(migrate1To2CreateTablesQueries); err != nil {
return err
}
if _, err := tx.Exec(createTablesQueriesNoTx); err != nil {
return err return err
} }
// Insert users from user_old into new user table, with ID and sync_topic // Insert users from user_old into new user table, with ID and sync_topic
@ -1356,6 +1436,22 @@ func migrateFrom1(db *sql.DB) error {
return nil return nil
} }
func migrateFrom2(db *sql.DB) error {
log.Tag(tag).Info("Migrating user database schema: from 2 to 3")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate2To3UpdateQueries); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 3); err != nil {
return err
}
return tx.Commit()
}
func nullString(s string) sql.NullString { func nullString(s string) sql.NullString {
if s == "" { if s == "" {
return sql.NullString{} return sql.NullString{}

View File

@ -4,6 +4,7 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stripe/stripe-go/v74"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"heckel.io/ntfy/util" "heckel.io/ntfy/util"
"net/netip" "net/netip"
@ -113,7 +114,8 @@ func TestManager_AddUser_And_Query(t *testing.T) {
require.Nil(t, a.ChangeBilling("user", &Billing{ require.Nil(t, a.ChangeBilling("user", &Billing{
StripeCustomerID: "acct_123", StripeCustomerID: "acct_123",
StripeSubscriptionID: "sub_123", StripeSubscriptionID: "sub_123",
StripeSubscriptionStatus: "active", StripeSubscriptionStatus: stripe.SubscriptionStatusActive,
StripeSubscriptionInterval: stripe.PriceRecurringIntervalMonth,
StripeSubscriptionPaidUntil: time.Now().Add(time.Hour), StripeSubscriptionPaidUntil: time.Now().Add(time.Hour),
StripeSubscriptionCancelAt: time.Unix(0, 0), StripeSubscriptionCancelAt: time.Unix(0, 0),
})) }))
@ -395,7 +397,7 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
require.Nil(t, a.AddTier(&Tier{ require.Nil(t, a.AddTier(&Tier{
Code: "pro", Code: "pro",
Name: "ntfy Pro", Name: "ntfy Pro",
StripePriceID: "price123", StripeMonthlyPriceID: "price123",
MessageLimit: 5_000, MessageLimit: 5_000,
MessageExpiryDuration: 3 * 24 * time.Hour, MessageExpiryDuration: 3 * 24 * time.Hour,
EmailLimit: 50, EmailLimit: 50,
@ -761,7 +763,7 @@ func TestManager_Tier_Create_Update_List_Delete(t *testing.T) {
AttachmentTotalSizeLimit: 1, AttachmentTotalSizeLimit: 1,
AttachmentExpiryDuration: time.Second, AttachmentExpiryDuration: time.Second,
AttachmentBandwidthLimit: 1, AttachmentBandwidthLimit: 1,
StripePriceID: "price_1", StripeMonthlyPriceID: "price_1",
})) }))
require.Nil(t, a.AddTier(&Tier{ require.Nil(t, a.AddTier(&Tier{
Code: "pro", Code: "pro",
@ -774,7 +776,7 @@ func TestManager_Tier_Create_Update_List_Delete(t *testing.T) {
AttachmentTotalSizeLimit: 123123, AttachmentTotalSizeLimit: 123123,
AttachmentExpiryDuration: 10800 * time.Second, AttachmentExpiryDuration: 10800 * time.Second,
AttachmentBandwidthLimit: 21474836480, AttachmentBandwidthLimit: 21474836480,
StripePriceID: "price_2", StripeMonthlyPriceID: "price_2",
})) }))
require.Nil(t, a.AddUser("phil", "phil", RoleUser)) require.Nil(t, a.AddUser("phil", "phil", RoleUser))
require.Nil(t, a.ChangeTier("phil", "pro")) require.Nil(t, a.ChangeTier("phil", "pro"))
@ -800,7 +802,7 @@ func TestManager_Tier_Create_Update_List_Delete(t *testing.T) {
require.Equal(t, int64(123123), ti.AttachmentTotalSizeLimit) require.Equal(t, int64(123123), ti.AttachmentTotalSizeLimit)
require.Equal(t, 10800*time.Second, ti.AttachmentExpiryDuration) require.Equal(t, 10800*time.Second, ti.AttachmentExpiryDuration)
require.Equal(t, int64(21474836480), ti.AttachmentBandwidthLimit) require.Equal(t, int64(21474836480), ti.AttachmentBandwidthLimit)
require.Equal(t, "price_2", ti.StripePriceID) require.Equal(t, "price_2", ti.StripeMonthlyPriceID)
// Update tier // Update tier
ti.EmailLimit = 999999 ti.EmailLimit = 999999
@ -822,7 +824,7 @@ func TestManager_Tier_Create_Update_List_Delete(t *testing.T) {
require.Equal(t, int64(1), ti.AttachmentTotalSizeLimit) require.Equal(t, int64(1), ti.AttachmentTotalSizeLimit)
require.Equal(t, time.Second, ti.AttachmentExpiryDuration) require.Equal(t, time.Second, ti.AttachmentExpiryDuration)
require.Equal(t, int64(1), ti.AttachmentBandwidthLimit) require.Equal(t, int64(1), ti.AttachmentBandwidthLimit)
require.Equal(t, "price_1", ti.StripePriceID) require.Equal(t, "price_1", ti.StripeMonthlyPriceID)
ti = tiers[1] ti = tiers[1]
require.Equal(t, "pro", ti.Code) require.Equal(t, "pro", ti.Code)
@ -835,7 +837,7 @@ func TestManager_Tier_Create_Update_List_Delete(t *testing.T) {
require.Equal(t, int64(123123), ti.AttachmentTotalSizeLimit) require.Equal(t, int64(123123), ti.AttachmentTotalSizeLimit)
require.Equal(t, 10800*time.Second, ti.AttachmentExpiryDuration) require.Equal(t, 10800*time.Second, ti.AttachmentExpiryDuration)
require.Equal(t, int64(21474836480), ti.AttachmentBandwidthLimit) require.Equal(t, int64(21474836480), ti.AttachmentBandwidthLimit)
require.Equal(t, "price_2", ti.StripePriceID) require.Equal(t, "price_2", ti.StripeMonthlyPriceID)
ti, err = a.TierByStripePrice("price_1") ti, err = a.TierByStripePrice("price_1")
require.Nil(t, err) require.Nil(t, err)
@ -849,7 +851,7 @@ func TestManager_Tier_Create_Update_List_Delete(t *testing.T) {
require.Equal(t, int64(1), ti.AttachmentTotalSizeLimit) require.Equal(t, int64(1), ti.AttachmentTotalSizeLimit)
require.Equal(t, time.Second, ti.AttachmentExpiryDuration) require.Equal(t, time.Second, ti.AttachmentExpiryDuration)
require.Equal(t, int64(1), ti.AttachmentBandwidthLimit) require.Equal(t, int64(1), ti.AttachmentBandwidthLimit)
require.Equal(t, "price_1", ti.StripePriceID) require.Equal(t, "price_1", ti.StripeMonthlyPriceID)
// Cannot remove tier, since user has this tier // Cannot remove tier, since user has this tier
require.Error(t, a.RemoveTier("pro")) require.Error(t, a.RemoveTier("pro"))

View File

@ -91,15 +91,17 @@ type Tier struct {
AttachmentTotalSizeLimit int64 // Total file size for all files of this user (bytes) AttachmentTotalSizeLimit int64 // Total file size for all files of this user (bytes)
AttachmentExpiryDuration time.Duration // Duration after which attachments will be deleted AttachmentExpiryDuration time.Duration // Duration after which attachments will be deleted
AttachmentBandwidthLimit int64 // Daily bandwidth limit for the user AttachmentBandwidthLimit int64 // Daily bandwidth limit for the user
StripePriceID string // Price ID for paid tiers (price_...) StripeMonthlyPriceID string // Monthly price ID for paid tiers (price_...)
StripeYearlyPriceID string // Yearly price ID for paid tiers (price_...)
} }
// Context returns fields for the log // Context returns fields for the log
func (t *Tier) Context() log.Context { func (t *Tier) Context() log.Context {
return log.Context{ return log.Context{
"tier_id": t.ID, "tier_id": t.ID,
"tier_code": t.Code, "tier_code": t.Code,
"stripe_price_id": t.StripePriceID, "stripe_monthly_price_id": t.StripeMonthlyPriceID,
"stripe_yearly_price_id": t.StripeYearlyPriceID,
} }
} }
@ -136,6 +138,7 @@ type Billing struct {
StripeCustomerID string StripeCustomerID string
StripeSubscriptionID string StripeSubscriptionID string
StripeSubscriptionStatus stripe.SubscriptionStatus StripeSubscriptionStatus stripe.SubscriptionStatus
StripeSubscriptionInterval stripe.PriceRecurringInterval
StripeSubscriptionPaidUntil time.Time StripeSubscriptionPaidUntil time.Time
StripeSubscriptionCancelAt time.Time StripeSubscriptionCancelAt time.Time
} }

View File

@ -49,12 +49,15 @@ func TestAllowedTier(t *testing.T) {
func TestTierContext(t *testing.T) { func TestTierContext(t *testing.T) {
tier := &Tier{ tier := &Tier{
ID: "ti_abc", ID: "ti_abc",
Code: "pro", Code: "pro",
StripePriceID: "price_123", StripeMonthlyPriceID: "price_123",
StripeYearlyPriceID: "price_456",
} }
context := tier.Context() context := tier.Context()
require.Equal(t, "ti_abc", context["tier_id"]) require.Equal(t, "ti_abc", context["tier_id"])
require.Equal(t, "pro", context["tier_code"]) require.Equal(t, "pro", context["tier_code"])
require.Equal(t, "price_123", context["stripe_price_id"]) require.Equal(t, "price_123", context["stripe_monthly_price_id"])
require.Equal(t, "price_456", context["stripe_yearly_price_id"])
} }

View File

@ -193,6 +193,8 @@
"account_basics_tier_admin_suffix_no_tier": "(no tier)", "account_basics_tier_admin_suffix_no_tier": "(no tier)",
"account_basics_tier_basic": "Basic", "account_basics_tier_basic": "Basic",
"account_basics_tier_free": "Free", "account_basics_tier_free": "Free",
"account_basics_tier_interval_monthly": "monthly",
"account_basics_tier_interval_yearly": "annually",
"account_basics_tier_upgrade_button": "Upgrade to Pro", "account_basics_tier_upgrade_button": "Upgrade to Pro",
"account_basics_tier_change_button": "Change", "account_basics_tier_change_button": "Change",
"account_basics_tier_paid_until": "Subscription paid until {{date}}, and will auto-renew", "account_basics_tier_paid_until": "Subscription paid until {{date}}, and will auto-renew",
@ -215,15 +217,21 @@
"account_delete_dialog_button_submit": "Permanently delete account", "account_delete_dialog_button_submit": "Permanently delete account",
"account_delete_dialog_billing_warning": "Deleting your account also cancels your billing subscription immediately. You will not have access to the billing dashboard anymore.", "account_delete_dialog_billing_warning": "Deleting your account also cancels your billing subscription immediately. You will not have access to the billing dashboard anymore.",
"account_upgrade_dialog_title": "Change account tier", "account_upgrade_dialog_title": "Change account tier",
"account_upgrade_dialog_interval_monthly": "Monthly",
"account_upgrade_dialog_interval_yearly": "Annually",
"account_upgrade_dialog_cancel_warning": "This will <strong>cancel your subscription</strong>, and downgrade your account on {{date}}. On that date, topic reservations as well as messages cached on the server <strong>will be deleted</strong>.", "account_upgrade_dialog_cancel_warning": "This will <strong>cancel your subscription</strong>, and downgrade your account on {{date}}. On that date, topic reservations as well as messages cached on the server <strong>will be deleted</strong>.",
"account_upgrade_dialog_proration_info": "<strong>Proration</strong>: When switching between paid plans, the price difference will be charged or refunded in the next invoice. You will not receive another invoice until the end of the next billing period.", "account_upgrade_dialog_proration_info": "<strong>Proration</strong>: When upgrading between paid plans, the price difference will be <strong>charged immediately</strong>. When downgrading to a lower tier, the balance will be used to pay for future billing periods.",
"account_upgrade_dialog_reservations_warning_one": "The selected tier allows fewer reserved topics than your current tier. Before changing your tier, <strong>please delete at least one reservation</strong>. You can remove reservations in the <Link>Settings</Link>.", "account_upgrade_dialog_reservations_warning_one": "The selected tier allows fewer reserved topics than your current tier. Before changing your tier, <strong>please delete at least one reservation</strong>. You can remove reservations in the <Link>Settings</Link>.",
"account_upgrade_dialog_reservations_warning_other": "The selected tier allows fewer reserved topics than your current tier. Before changing your tier, <strong>please delete at least {{count}} reservations</strong>. You can remove reservations in the <Link>Settings</Link>.", "account_upgrade_dialog_reservations_warning_other": "The selected tier allows fewer reserved topics than your current tier. Before changing your tier, <strong>please delete at least {{count}} reservations</strong>. You can remove reservations in the <Link>Settings</Link>.",
"account_upgrade_dialog_tier_features_reservations": "{{reservations}} reserved topics", "account_upgrade_dialog_tier_features_reservations": "{{reservations}} reserved topics",
"account_upgrade_dialog_tier_features_no_reservations": "No reserved topics",
"account_upgrade_dialog_tier_features_messages": "{{messages}} daily messages", "account_upgrade_dialog_tier_features_messages": "{{messages}} daily messages",
"account_upgrade_dialog_tier_features_emails": "{{emails}} daily emails", "account_upgrade_dialog_tier_features_emails": "{{emails}} daily emails",
"account_upgrade_dialog_tier_features_attachment_file_size": "{{filesize}} per file", "account_upgrade_dialog_tier_features_attachment_file_size": "{{filesize}} per file",
"account_upgrade_dialog_tier_features_attachment_total_size": "{{totalsize}} total storage", "account_upgrade_dialog_tier_features_attachment_total_size": "{{totalsize}} total storage",
"account_upgrade_dialog_tier_price_per_month": "month",
"account_upgrade_dialog_tier_price_billed_monthly": "{{price}} per year. Billed monthly.",
"account_upgrade_dialog_tier_price_billed_yearly": "{{price}} billed annually. Save {{save}}.",
"account_upgrade_dialog_tier_selected_label": "Selected", "account_upgrade_dialog_tier_selected_label": "Selected",
"account_upgrade_dialog_tier_current_label": "Current", "account_upgrade_dialog_tier_current_label": "Current",
"account_upgrade_dialog_button_cancel": "Cancel", "account_upgrade_dialog_button_cancel": "Cancel",

View File

@ -257,23 +257,24 @@ class AccountApi {
return this.tiers; return this.tiers;
} }
async createBillingSubscription(tier) { async createBillingSubscription(tier, interval) {
console.log(`[AccountApi] Creating billing subscription with ${tier}`); console.log(`[AccountApi] Creating billing subscription with ${tier} and interval ${interval}`);
return await this.upsertBillingSubscription("POST", tier) return await this.upsertBillingSubscription("POST", tier, interval)
} }
async updateBillingSubscription(tier) { async updateBillingSubscription(tier, interval) {
console.log(`[AccountApi] Updating billing subscription with ${tier}`); console.log(`[AccountApi] Updating billing subscription with ${tier} and interval ${interval}`);
return await this.upsertBillingSubscription("PUT", tier) return await this.upsertBillingSubscription("PUT", tier, interval)
} }
async upsertBillingSubscription(method, tier) { async upsertBillingSubscription(method, tier, interval) {
const url = accountBillingSubscriptionUrl(config.base_url); const url = accountBillingSubscriptionUrl(config.base_url);
const response = await fetchOrThrow(url, { const response = await fetchOrThrow(url, {
method: method, method: method,
headers: withBearerAuth({}, session.token()), headers: withBearerAuth({}, session.token()),
body: JSON.stringify({ body: JSON.stringify({
tier: tier tier: tier,
interval: interval
}) })
}); });
return await response.json(); // May throw SyntaxError return await response.json(); // May throw SyntaxError
@ -371,6 +372,12 @@ export const SubscriptionStatus = {
PAST_DUE: "past_due" PAST_DUE: "past_due"
}; };
// Maps to stripe.PriceRecurringInterval
export const SubscriptionInterval = {
MONTH: "month",
YEAR: "year"
};
// Maps to user.Permission in user/types.go // Maps to user.Permission in user/types.go
export const Permission = { export const Permission = {
READ_WRITE: "read-write", READ_WRITE: "read-write",

View File

@ -212,6 +212,13 @@ export const formatNumber = (n) => {
return n; return n;
} }
export const formatPrice = (n) => {
if (n % 100 === 0) {
return `$${n/100}`;
}
return `$${(n/100).toPrecision(2)}`;
}
export const openUrl = (url) => { export const openUrl = (url) => {
window.open(url, "_blank", "noopener,noreferrer"); window.open(url, "_blank", "noopener,noreferrer");
}; };

View File

@ -35,7 +35,7 @@ import TextField from "@mui/material/TextField";
import routes from "./routes"; import routes from "./routes";
import IconButton from "@mui/material/IconButton"; import IconButton from "@mui/material/IconButton";
import {formatBytes, formatShortDate, formatShortDateTime, openUrl} from "../app/utils"; import {formatBytes, formatShortDate, formatShortDateTime, openUrl} from "../app/utils";
import accountApi, {LimitBasis, Role, SubscriptionStatus} from "../app/AccountApi"; import accountApi, {LimitBasis, Role, SubscriptionInterval, SubscriptionStatus} from "../app/AccountApi";
import InfoOutlinedIcon from '@mui/icons-material/InfoOutlined'; import InfoOutlinedIcon from '@mui/icons-material/InfoOutlined';
import {Pref, PrefGroup} from "./Pref"; import {Pref, PrefGroup} from "./Pref";
import db from "../app/db"; import db from "../app/db";
@ -248,6 +248,11 @@ const AccountType = () => {
accountType = (config.enable_payments) ? t("account_basics_tier_free") : t("account_basics_tier_basic"); accountType = (config.enable_payments) ? t("account_basics_tier_free") : t("account_basics_tier_basic");
} else { } else {
accountType = account.tier.name; accountType = account.tier.name;
if (account.billing?.interval === SubscriptionInterval.MONTH) {
accountType += ` (${t("account_basics_tier_interval_monthly")})`;
} else if (account.billing?.interval === SubscriptionInterval.YEAR) {
accountType += ` (${t("account_basics_tier_interval_yearly")})`;
}
} }
return ( return (

View File

@ -3,20 +3,20 @@ import {useContext, useEffect, useState} from 'react';
import Dialog from '@mui/material/Dialog'; import Dialog from '@mui/material/Dialog';
import DialogContent from '@mui/material/DialogContent'; import DialogContent from '@mui/material/DialogContent';
import DialogTitle from '@mui/material/DialogTitle'; import DialogTitle from '@mui/material/DialogTitle';
import {Alert, CardActionArea, CardContent, ListItem, useMediaQuery} from "@mui/material"; import {Alert, Badge, CardActionArea, CardContent, Chip, ListItem, Stack, Switch, useMediaQuery} from "@mui/material";
import theme from "./theme"; import theme from "./theme";
import DialogFooter from "./DialogFooter"; import DialogFooter from "./DialogFooter";
import Button from "@mui/material/Button"; import Button from "@mui/material/Button";
import accountApi from "../app/AccountApi"; import accountApi, {SubscriptionInterval} from "../app/AccountApi";
import session from "../app/Session"; import session from "../app/Session";
import routes from "./routes"; import routes from "./routes";
import Card from "@mui/material/Card"; import Card from "@mui/material/Card";
import Typography from "@mui/material/Typography"; import Typography from "@mui/material/Typography";
import {AccountContext} from "./App"; import {AccountContext} from "./App";
import {formatBytes, formatNumber, formatShortDate} from "../app/utils"; import {formatBytes, formatNumber, formatPrice, formatShortDate} from "../app/utils";
import {Trans, useTranslation} from "react-i18next"; import {Trans, useTranslation} from "react-i18next";
import List from "@mui/material/List"; import List from "@mui/material/List";
import {Check} from "@mui/icons-material"; import {Check, Close} from "@mui/icons-material";
import ListItemIcon from "@mui/material/ListItemIcon"; import ListItemIcon from "@mui/material/ListItemIcon";
import ListItemText from "@mui/material/ListItemText"; import ListItemText from "@mui/material/ListItemText";
import Box from "@mui/material/Box"; import Box from "@mui/material/Box";
@ -28,6 +28,7 @@ const UpgradeDialog = (props) => {
const { account } = useContext(AccountContext); // May be undefined! const { account } = useContext(AccountContext); // May be undefined!
const [error, setError] = useState(""); const [error, setError] = useState("");
const [tiers, setTiers] = useState(null); const [tiers, setTiers] = useState(null);
const [interval, setInterval] = useState(account?.billing?.interval || SubscriptionInterval.YEAR);
const [newTierCode, setNewTierCode] = useState(account?.tier?.code); // May be undefined const [newTierCode, setNewTierCode] = useState(account?.tier?.code); // May be undefined
const [loading, setLoading] = useState(false); const [loading, setLoading] = useState(false);
const fullScreen = useMediaQuery(theme.breakpoints.down('sm')); const fullScreen = useMediaQuery(theme.breakpoints.down('sm'));
@ -46,6 +47,7 @@ const UpgradeDialog = (props) => {
const tiersMap = Object.assign(...tiers.map(tier => ({[tier.code]: tier}))); const tiersMap = Object.assign(...tiers.map(tier => ({[tier.code]: tier})));
const newTier = tiersMap[newTierCode]; // May be undefined const newTier = tiersMap[newTierCode]; // May be undefined
const currentTier = account?.tier; // May be undefined const currentTier = account?.tier; // May be undefined
const currentInterval = account?.billing?.interval; // May be undefined
const currentTierCode = currentTier?.code; // May be undefined const currentTierCode = currentTier?.code; // May be undefined
// Figure out buttons, labels and the submit action // Figure out buttons, labels and the submit action
@ -54,7 +56,7 @@ const UpgradeDialog = (props) => {
submitButtonLabel = t("account_upgrade_dialog_button_redirect_signup"); submitButtonLabel = t("account_upgrade_dialog_button_redirect_signup");
submitAction = Action.REDIRECT_SIGNUP; submitAction = Action.REDIRECT_SIGNUP;
banner = null; banner = null;
} else if (currentTierCode === newTierCode) { } else if (currentTierCode === newTierCode && currentInterval === interval) {
submitButtonLabel = t("account_upgrade_dialog_button_update_subscription"); submitButtonLabel = t("account_upgrade_dialog_button_update_subscription");
submitAction = null; submitAction = null;
banner = (currentTierCode) ? Banner.PRORATION_INFO : null; banner = (currentTierCode) ? Banner.PRORATION_INFO : null;
@ -88,10 +90,10 @@ const UpgradeDialog = (props) => {
try { try {
setLoading(true); setLoading(true);
if (submitAction === Action.CREATE_SUBSCRIPTION) { if (submitAction === Action.CREATE_SUBSCRIPTION) {
const response = await accountApi.createBillingSubscription(newTierCode); const response = await accountApi.createBillingSubscription(newTierCode, interval);
window.location.href = response.redirect_url; window.location.href = response.redirect_url;
} else if (submitAction === Action.UPDATE_SUBSCRIPTION) { } else if (submitAction === Action.UPDATE_SUBSCRIPTION) {
await accountApi.updateBillingSubscription(newTierCode); await accountApi.updateBillingSubscription(newTierCode, interval);
} else if (submitAction === Action.CANCEL_SUBSCRIPTION) { } else if (submitAction === Action.CANCEL_SUBSCRIPTION) {
await accountApi.deleteBillingSubscription(); await accountApi.deleteBillingSubscription();
} }
@ -108,15 +110,45 @@ const UpgradeDialog = (props) => {
} }
} }
// Figure out discount
let discount;
if (newTier?.prices) {
discount = Math.round(((newTier.prices.month*12/newTier.prices.year)-1)*100);
} else {
for (const t of tiers) {
if (t.prices) {
discount = Math.round(((t.prices.month*12/t.prices.year)-1)*100);
break;
}
}
}
return ( return (
<Dialog <Dialog
open={props.open} open={props.open}
onClose={props.onCancel} onClose={props.onCancel}
maxWidth="md" maxWidth="lg"
fullWidth
fullScreen={fullScreen} fullScreen={fullScreen}
> >
<DialogTitle>{t("account_upgrade_dialog_title")}</DialogTitle> <DialogTitle>
<div style={{ display: "flex", flexDirection: "row" }}>
<div style={{ flexGrow: 1 }}>{t("account_upgrade_dialog_title")}</div>
<div style={{
display: "flex",
flexDirection: "row",
alignItems: "center",
marginTop: "4px"
}}>
<Typography component="span" variant="subtitle1">{t("account_upgrade_dialog_interval_monthly")}</Typography>
<Switch
checked={interval === SubscriptionInterval.YEAR}
onChange={(ev) => setInterval(ev.target.checked ? SubscriptionInterval.YEAR : SubscriptionInterval.MONTH)}
/>
<Typography component="span" variant="subtitle1">{t("account_upgrade_dialog_interval_yearly")}</Typography>
{discount > 0 && <Chip label={`-${discount}%`} color="primary" size="small" sx={{ marginLeft: "5px" }}/>}
</div>
</div>
</DialogTitle>
<DialogContent> <DialogContent>
<div style={{ <div style={{
display: "flex", display: "flex",
@ -130,24 +162,25 @@ const UpgradeDialog = (props) => {
tier={tier} tier={tier}
current={currentTierCode === tier.code} // tier.code or currentTierCode may be undefined! current={currentTierCode === tier.code} // tier.code or currentTierCode may be undefined!
selected={newTierCode === tier.code} // tier.code may be undefined! selected={newTierCode === tier.code} // tier.code may be undefined!
interval={interval}
onClick={() => setNewTierCode(tier.code)} // tier.code may be undefined! onClick={() => setNewTierCode(tier.code)} // tier.code may be undefined!
/> />
)} )}
</div> </div>
{banner === Banner.CANCEL_WARNING && {banner === Banner.CANCEL_WARNING &&
<Alert severity="warning"> <Alert severity="warning" sx={{ fontSize: "1rem" }}>
<Trans <Trans
i18nKey="account_upgrade_dialog_cancel_warning" i18nKey="account_upgrade_dialog_cancel_warning"
values={{ date: formatShortDate(account?.billing?.paid_until || 0) }} /> values={{ date: formatShortDate(account?.billing?.paid_until || 0) }} />
</Alert> </Alert>
} }
{banner === Banner.PRORATION_INFO && {banner === Banner.PRORATION_INFO &&
<Alert severity="info"> <Alert severity="info" sx={{ fontSize: "1rem" }}>
<Trans i18nKey="account_upgrade_dialog_proration_info" /> <Trans i18nKey="account_upgrade_dialog_proration_info" />
</Alert> </Alert>
} }
{banner === Banner.RESERVATIONS_WARNING && {banner === Banner.RESERVATIONS_WARNING &&
<Alert severity="warning"> <Alert severity="warning" sx={{ fontSize: "1rem" }}>
<Trans <Trans
i18nKey="account_upgrade_dialog_reservations_warning" i18nKey="account_upgrade_dialog_reservations_warning"
count={account?.reservations.length - newTier?.limits.reservations} count={account?.reservations.length - newTier?.limits.reservations}
@ -169,28 +202,37 @@ const UpgradeDialog = (props) => {
const TierCard = (props) => { const TierCard = (props) => {
const { t } = useTranslation(); const { t } = useTranslation();
const tier = props.tier; const tier = props.tier;
let cardStyle, labelStyle, labelText; let cardStyle, labelStyle, labelText;
if (props.selected) { if (props.selected) {
cardStyle = { background: "#eee", border: "2px solid #338574" }; cardStyle = { background: "#eee", border: "3px solid #338574" };
labelStyle = { background: "#338574", color: "white" }; labelStyle = { background: "#338574", color: "white" };
labelText = t("account_upgrade_dialog_tier_selected_label"); labelText = t("account_upgrade_dialog_tier_selected_label");
} else if (props.current) { } else if (props.current) {
cardStyle = { border: "2px solid #eee" }; cardStyle = { border: "3px solid #eee" };
labelStyle = { background: "#eee", color: "black" }; labelStyle = { background: "#eee", color: "black" };
labelText = t("account_upgrade_dialog_tier_current_label"); labelText = t("account_upgrade_dialog_tier_current_label");
} else { } else {
cardStyle = { border: "2px solid transparent" }; cardStyle = { border: "3px solid transparent" };
}
let monthlyPrice;
if (!tier.prices) {
monthlyPrice = 0;
} else if (props.interval === SubscriptionInterval.YEAR) {
monthlyPrice = tier.prices.year/12;
} else if (props.interval === SubscriptionInterval.MONTH) {
monthlyPrice = tier.prices.month;
} }
return ( return (
<Box sx={{ <Box sx={{
m: "7px", m: "7px",
minWidth: "190px", minWidth: "240px",
maxWidth: "250px",
flexGrow: 1, flexGrow: 1,
flexShrink: 1, flexShrink: 1,
flexBasis: 0, flexBasis: 0,
borderRadius: "3px", borderRadius: "5px",
"&:first-of-type": { ml: 0 }, "&:first-of-type": { ml: 0 },
"&:last-of-type": { mr: 0 }, "&:last-of-type": { mr: 0 },
...cardStyle ...cardStyle
@ -208,19 +250,29 @@ const TierCard = (props) => {
...labelStyle ...labelStyle
}}>{labelText}</div> }}>{labelText}</div>
} }
<Typography variant="h5" component="div"> <Typography variant="subtitle1" component="div">
{tier.name || t("account_basics_tier_free")} {tier.name || t("account_basics_tier_free")}
</Typography> </Typography>
<div>
<Typography component="span" variant="h4" sx={{ fontWeight: 500, marginRight: "3px" }}>{formatPrice(monthlyPrice)}</Typography>
{monthlyPrice > 0 && <>/ {t("account_upgrade_dialog_tier_price_per_month")}</>}
</div>
<List dense> <List dense>
{tier.limits.reservations > 0 && <FeatureItem>{t("account_upgrade_dialog_tier_features_reservations", { reservations: tier.limits.reservations })}</FeatureItem>} {tier.limits.reservations > 0 && <Feature>{t("account_upgrade_dialog_tier_features_reservations", { reservations: tier.limits.reservations })}</Feature>}
<FeatureItem>{t("account_upgrade_dialog_tier_features_messages", { messages: formatNumber(tier.limits.messages) })}</FeatureItem> {tier.limits.reservations === 0 && <NoFeature>{t("account_upgrade_dialog_tier_features_no_reservations")}</NoFeature>}
<FeatureItem>{t("account_upgrade_dialog_tier_features_emails", { emails: formatNumber(tier.limits.emails) })}</FeatureItem> <Feature>{t("account_upgrade_dialog_tier_features_messages", { messages: formatNumber(tier.limits.messages) })}</Feature>
<FeatureItem>{t("account_upgrade_dialog_tier_features_attachment_file_size", { filesize: formatBytes(tier.limits.attachment_file_size, 0) })}</FeatureItem> <Feature>{t("account_upgrade_dialog_tier_features_emails", { emails: formatNumber(tier.limits.emails) })}</Feature>
<FeatureItem>{t("account_upgrade_dialog_tier_features_attachment_total_size", { totalsize: formatBytes(tier.limits.attachment_total_size, 0) })}</FeatureItem> <Feature>{t("account_upgrade_dialog_tier_features_attachment_file_size", { filesize: formatBytes(tier.limits.attachment_file_size, 0) })}</Feature>
<Feature>{t("account_upgrade_dialog_tier_features_attachment_total_size", { totalsize: formatBytes(tier.limits.attachment_total_size, 0) })}</Feature>
</List> </List>
{tier.price && {tier.prices && props.interval === SubscriptionInterval.MONTH &&
<Typography variant="subtitle1" sx={{fontWeight: 500}}> <Typography variant="body2" color="gray">
{tier.price} / month {t("account_upgrade_dialog_tier_price_billed_monthly", { price: formatPrice(tier.prices.month*12) })}
</Typography>
}
{tier.prices && props.interval === SubscriptionInterval.YEAR &&
<Typography variant="body2" color="gray">
{t("account_upgrade_dialog_tier_price_billed_yearly", { price: formatPrice(tier.prices.year), save: formatPrice(tier.prices.month*12-tier.prices.year) })}
</Typography> </Typography>
} }
</CardContent> </CardContent>
@ -231,16 +283,25 @@ const TierCard = (props) => {
); );
} }
const Feature = (props) => {
return <FeatureItem feature={true}>{props.children}</FeatureItem>;
}
const NoFeature = (props) => {
return <FeatureItem feature={false}>{props.children}</FeatureItem>;
}
const FeatureItem = (props) => { const FeatureItem = (props) => {
return ( return (
<ListItem disableGutters sx={{m: 0, p: 0}}> <ListItem disableGutters sx={{m: 0, p: 0}}>
<ListItemIcon sx={{minWidth: "24px"}}> <ListItemIcon sx={{minWidth: "24px"}}>
<Check fontSize="small" sx={{ color: "#338574" }}/> {props.feature && <Check fontSize="small" sx={{ color: "#338574" }}/>}
{!props.feature && <Close fontSize="small" sx={{ color: "gray" }}/>}
</ListItemIcon> </ListItemIcon>
<ListItemText <ListItemText
sx={{mt: "2px", mb: "2px"}} sx={{mt: "2px", mb: "2px"}}
primary={ primary={
<Typography variant="body2"> <Typography variant="body1">
{props.children} {props.children}
</Typography> </Typography>
} }