diff --git a/cmd/tier.go b/cmd/tier.go index 1c3ede9f..2c06e98d 100644 --- a/cmd/tier.go +++ b/cmd/tier.go @@ -54,7 +54,8 @@ var cmdTier = &cli.Command{ &cli.StringFlag{Name: "attachment-total-size-limit", Value: defaultAttachmentTotalSizeLimit, Usage: "total size limit of attachments for the user"}, &cli.DurationFlag{Name: "attachment-expiry-duration", Value: defaultAttachmentExpiryDuration, Usage: "duration after which attachments are deleted"}, &cli.StringFlag{Name: "attachment-bandwidth-limit", Value: defaultAttachmentBandwidthLimit, Usage: "daily bandwidth limit for attachment uploads/downloads"}, - &cli.StringFlag{Name: "stripe-price-id", Usage: "Stripe price ID for paid tiers (e.g. price_12345)"}, + &cli.StringFlag{Name: "stripe-monthly-price-id", Usage: "Monthly Stripe price ID for paid tiers (e.g. price_12345)"}, + &cli.StringFlag{Name: "stripe-yearly-price-id", Usage: "Yearly Stripe price ID for paid tiers (e.g. price_12345)"}, &cli.BoolFlag{Name: "ignore-exists", Usage: "if the tier already exists, perform no action and exit"}, }, Description: `Add a new tier to the ntfy user database. @@ -96,7 +97,8 @@ Examples: &cli.StringFlag{Name: "attachment-total-size-limit", Usage: "total size limit of attachments for the user"}, &cli.DurationFlag{Name: "attachment-expiry-duration", Usage: "duration after which attachments are deleted"}, &cli.StringFlag{Name: "attachment-bandwidth-limit", Usage: "daily bandwidth limit for attachment uploads/downloads"}, - &cli.StringFlag{Name: "stripe-price-id", Usage: "Stripe price ID for paid tiers (e.g. price_12345)"}, + &cli.StringFlag{Name: "stripe-monthly-price-id", Usage: "Monthly Stripe price ID for paid tiers (e.g. price_12345)"}, + &cli.StringFlag{Name: "stripe-yearly-price-id", Usage: "Yearly Stripe price ID for paid tiers (e.g. price_12345)"}, }, Description: `Updates a tier to change the limits. @@ -110,7 +112,8 @@ Examples: ntfy tier change --name="Pro" pro # Update the name of an existing tier ntfy tier change \ # Update multiple limits and fields --message-expiry-duration=24h \ - --stripe-price-id=price_1234 \ + --stripe-monthly-price-id=price_1234 \ + --stripe-monthly-price-id=price_5678 \ pro `, }, @@ -166,6 +169,10 @@ func execTierAdd(c *cli.Context) error { return errors.New("tier code expected, type 'ntfy tier add --help' for help") } else if !user.AllowedTier(code) { return errors.New("tier code must consist only of numbers and letters") + } else if c.String("stripe-monthly-price-id") != "" && c.String("stripe-yearly-price-id") == "" { + return errors.New("if stripe-monthly-price-id is set, stripe-yearly-price-id must also be set") + } else if c.String("stripe-monthly-price-id") == "" && c.String("stripe-yearly-price-id") != "" { + return errors.New("if stripe-yearly-price-id is set, stripe-monthly-price-id must also be set") } manager, err := createUserManager(c) if err != nil { @@ -206,7 +213,8 @@ func execTierAdd(c *cli.Context) error { AttachmentTotalSizeLimit: attachmentTotalSizeLimit, AttachmentExpiryDuration: c.Duration("attachment-expiry-duration"), AttachmentBandwidthLimit: attachmentBandwidthLimit, - StripePriceID: c.String("stripe-price-id"), + StripeMonthlyPriceID: c.String("stripe-monthly-price-id"), + StripeYearlyPriceID: c.String("stripe-yearly-price-id"), } if err := manager.AddTier(tier); err != nil { return err @@ -273,8 +281,16 @@ func execTierChange(c *cli.Context) error { return err } } - if c.IsSet("stripe-price-id") { - tier.StripePriceID = c.String("stripe-price-id") + if c.IsSet("stripe-monthly-price-id") { + tier.StripeMonthlyPriceID = c.String("stripe-monthly-price-id") + } + if c.IsSet("stripe-yearly-price-id") { + tier.StripeYearlyPriceID = c.String("stripe-yearly-price-id") + } + if tier.StripeMonthlyPriceID != "" && tier.StripeYearlyPriceID == "" { + return errors.New("if stripe-monthly-price-id is set, stripe-yearly-price-id must also be set") + } else if tier.StripeMonthlyPriceID == "" && tier.StripeYearlyPriceID != "" { + return errors.New("if stripe-yearly-price-id is set, stripe-monthly-price-id must also be set") } if err := manager.UpdateTier(tier); err != nil { return err @@ -319,9 +335,9 @@ func execTierList(c *cli.Context) error { } func printTier(c *cli.Context, tier *user.Tier) { - stripePriceID := tier.StripePriceID - if stripePriceID == "" { - stripePriceID = "(none)" + prices := "(none)" + if tier.StripeMonthlyPriceID != "" && tier.StripeYearlyPriceID != "" { + prices = fmt.Sprintf("%s / %s", tier.StripeMonthlyPriceID, tier.StripeYearlyPriceID) } fmt.Fprintf(c.App.ErrWriter, "tier %s (id: %s)\n", tier.Code, tier.ID) fmt.Fprintf(c.App.ErrWriter, "- Name: %s\n", tier.Name) @@ -333,5 +349,5 @@ func printTier(c *cli.Context, tier *user.Tier) { fmt.Fprintf(c.App.ErrWriter, "- Attachment total size limit: %s\n", util.FormatSize(tier.AttachmentTotalSizeLimit)) fmt.Fprintf(c.App.ErrWriter, "- Attachment expiry duration: %s (%d seconds)\n", tier.AttachmentExpiryDuration.String(), int64(tier.AttachmentExpiryDuration.Seconds())) fmt.Fprintf(c.App.ErrWriter, "- Attachment daily bandwidth limit: %s\n", util.FormatSize(tier.AttachmentBandwidthLimit)) - fmt.Fprintf(c.App.ErrWriter, "- Stripe price: %s\n", stripePriceID) + fmt.Fprintf(c.App.ErrWriter, "- Stripe prices (monthly/yearly): %s\n", prices) } diff --git a/cmd/tier_test.go b/cmd/tier_test.go index 12343e4b..a447bfbc 100644 --- a/cmd/tier_test.go +++ b/cmd/tier_test.go @@ -36,7 +36,8 @@ func TestCLI_Tier_AddListChangeDelete(t *testing.T) { "--attachment-expiry-duration=7h", "--attachment-total-size-limit=10G", "--attachment-bandwidth-limit=100G", - "--stripe-price-id=price_991", + "--stripe-monthly-price-id=price_991", + "--stripe-yearly-price-id=price_992", "pro", )) require.Contains(t, stderr.String(), "- Message limit: 999") @@ -46,7 +47,7 @@ func TestCLI_Tier_AddListChangeDelete(t *testing.T) { require.Contains(t, stderr.String(), "- Attachment file size limit: 100.0 MB") require.Contains(t, stderr.String(), "- Attachment expiry duration: 7h") require.Contains(t, stderr.String(), "- Attachment total size limit: 10.0 GB") - require.Contains(t, stderr.String(), "- Stripe price: price_991") + require.Contains(t, stderr.String(), "- Stripe prices (monthly/yearly): price_991 / price_992") app, _, _, stderr = newTestApp() require.Nil(t, runTierCommand(app, conf, "remove", "pro")) diff --git a/docs/releases.md b/docs/releases.md index a2cd1d05..2354f8ee 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -8,6 +8,7 @@ and the [ntfy Android app](https://github.com/binwiederhier/ntfy-android/release * Support for publishing to protected topics via email with access tokens ([#612](https://github.com/binwiederhier/ntfy/pull/621), thanks to [@tamcore](https://github.com/tamcore)) * Support for base64-encoded and nested multipart emails ([#610](https://github.com/binwiederhier/ntfy/issues/610), thanks to [@Robert-litts](https://github.com/Robert-litts)) +* Add support for annual billing intervals (no ticket) **Bug fixes + maintenance:** diff --git a/server/server.go b/server/server.go index 1890f179..c4e07238 100644 --- a/server/server.go +++ b/server/server.go @@ -45,11 +45,11 @@ type Server struct { visitors map[string]*visitor // ip: or user: firebaseClient *firebaseClient messages int64 - userManager *user.Manager // Might be nil! - messageCache *messageCache // Database that stores the messages - fileCache *fileCache // File system based cache that stores attachments - stripe stripeAPI // Stripe API, can be replaced with a mock - priceCache *util.LookupCache[map[string]string] // Stripe price ID -> formatted price + userManager *user.Manager // Might be nil! + messageCache *messageCache // Database that stores the messages + fileCache *fileCache // File system based cache that stores attachments + stripe stripeAPI // Stripe API, can be replaced with a mock + priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!) closeChan chan bool mu sync.Mutex } diff --git a/server/server_account.go b/server/server_account.go index 26ef7216..1b2c0ce4 100644 --- a/server/server_account.go +++ b/server/server_account.go @@ -100,6 +100,7 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, r *http.Request, v *vis Customer: true, Subscription: u.Billing.StripeSubscriptionID != "", Status: string(u.Billing.StripeSubscriptionStatus), + Interval: string(u.Billing.StripeSubscriptionInterval), PaidUntil: u.Billing.StripeSubscriptionPaidUntil.Unix(), CancelAt: u.Billing.StripeSubscriptionCancelAt.Unix(), } diff --git a/server/server_payments.go b/server/server_payments.go index d812837f..583a253f 100644 --- a/server/server_payments.go +++ b/server/server_payments.go @@ -80,14 +80,17 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _ return err } for _, tier := range tiers { - priceStr, ok := prices[tier.StripePriceID] - if tier.StripePriceID == "" || !ok { + priceMonth, priceYear := prices[tier.StripeMonthlyPriceID], prices[tier.StripeYearlyPriceID] + if priceMonth == 0 || priceYear == 0 { // Only allow tiers that have both prices! continue } response = append(response, &apiAccountBillingTier{ - Code: tier.Code, - Name: tier.Name, - Price: priceStr, + Code: tier.Code, + Name: tier.Name, + Prices: &apiAccountBillingPrices{ + Month: priceMonth, + Year: priceYear, + }, Limits: &apiAccountLimits{ Basis: string(visitorLimitBasisTier), Messages: tier.MessageLimit, @@ -117,11 +120,21 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r tier, err := s.userManager.Tier(req.Tier) if err != nil { return err - } else if tier.StripePriceID == "" { + } + var priceID string + if req.Interval == string(stripe.PriceRecurringIntervalMonth) && tier.StripeMonthlyPriceID != "" { + priceID = tier.StripeMonthlyPriceID + } else if req.Interval == string(stripe.PriceRecurringIntervalYear) && tier.StripeYearlyPriceID != "" { + priceID = tier.StripeYearlyPriceID + } else { return errNotAPaidTier } logvr(v, r). With(tier). + Fields(log.Context{ + "stripe_price_id": priceID, + "stripe_subscription_interval": req.Interval, + }). Tag(tagStripe). Info("Creating Stripe checkout flow") var stripeCustomerID *string @@ -143,7 +156,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r AllowPromotionCodes: stripe.Bool(true), LineItems: []*stripe.CheckoutSessionLineItemParams{ { - Price: stripe.String(tier.StripePriceID), + Price: stripe.String(priceID), Quantity: stripe.Int64(1), }, }, @@ -180,10 +193,11 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr sub, err := s.stripe.GetSubscription(sess.Subscription.ID) if err != nil { return err - } else if sub.Items == nil || len(sub.Items.Data) != 1 || sub.Items.Data[0].Price == nil { + } else if sub.Items == nil || len(sub.Items.Data) != 1 || sub.Items.Data[0].Price == nil || sub.Items.Data[0].Price.Recurring == nil { return wrapErrHTTP(errHTTPBadRequestBillingRequestInvalid, "more than one line item in existing subscription") } - tier, err := s.userManager.TierByStripePrice(sub.Items.Data[0].Price.ID) + priceID, interval := sub.Items.Data[0].Price.ID, sub.Items.Data[0].Price.Recurring.Interval + tier, err := s.userManager.TierByStripePrice(priceID) if err != nil { return err } @@ -197,8 +211,10 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr Tag(tagStripe). Fields(log.Context{ "stripe_customer_id": sess.Customer.ID, + "stripe_price_id": priceID, "stripe_subscription_id": sub.ID, "stripe_subscription_status": string(sub.Status), + "stripe_subscription_interval": string(interval), "stripe_subscription_paid_until": sub.CurrentPeriodEnd, }). Info("Stripe checkout flow succeeded, updating user tier and subscription") @@ -213,7 +229,7 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr if _, err := s.stripe.UpdateCustomer(sess.Customer.ID, customerParams); err != nil { return err } - if err := s.updateSubscriptionAndTier(r, v, u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil { + if err := s.updateSubscriptionAndTier(r, v, u, tier, sess.Customer.ID, sub.ID, string(sub.Status), string(interval), sub.CurrentPeriodEnd, sub.CancelAt); err != nil { return err } http.Redirect(w, r, s.config.BaseURL+accountPath, http.StatusSeeOther) @@ -235,15 +251,24 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r if err != nil { return err } + var priceID string + if req.Interval == string(stripe.PriceRecurringIntervalMonth) && tier.StripeMonthlyPriceID != "" { + priceID = tier.StripeMonthlyPriceID + } else if req.Interval == string(stripe.PriceRecurringIntervalYear) && tier.StripeYearlyPriceID != "" { + priceID = tier.StripeYearlyPriceID + } else { + return errNotAPaidTier + } logvr(v, r). Tag(tagStripe). Fields(log.Context{ - "new_tier_id": tier.ID, - "new_tier_name": tier.Name, - "new_tier_stripe_price_id": tier.StripePriceID, + "new_tier_id": tier.ID, + "new_tier_code": tier.Code, + "new_tier_stripe_price_id": priceID, + "new_tier_stripe_subscription_interval": req.Interval, // Other stripe_* fields filled by visitor context }). - Info("Changing Stripe subscription and billing tier to %s/%s (price %s)", tier.ID, tier.Name, tier.StripePriceID) + Info("Changing Stripe subscription and billing tier to %s/%s (price %s, %s)", tier.ID, tier.Name, priceID, req.Interval) sub, err := s.stripe.GetSubscription(u.Billing.StripeSubscriptionID) if err != nil { return err @@ -252,11 +277,11 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r } params := &stripe.SubscriptionParams{ CancelAtPeriodEnd: stripe.Bool(false), - ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)), + ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorAlwaysInvoice)), Items: []*stripe.SubscriptionItemsParams{ { ID: stripe.String(sub.Items.Data[0].ID), - Price: stripe.String(tier.StripePriceID), + Price: stripe.String(priceID), }, }, } @@ -345,20 +370,22 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(r *http.Request, ev, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event.Data.Raw))) if err != nil { return err - } else if ev.ID == "" || ev.Customer == "" || ev.Status == "" || ev.CurrentPeriodEnd == 0 || ev.Items == nil || len(ev.Items.Data) != 1 || ev.Items.Data[0].Price == nil || ev.Items.Data[0].Price.ID == "" { + } else if ev.ID == "" || ev.Customer == "" || ev.Status == "" || ev.CurrentPeriodEnd == 0 || ev.Items == nil || len(ev.Items.Data) != 1 || ev.Items.Data[0].Price == nil || ev.Items.Data[0].Price.ID == "" || ev.Items.Data[0].Price.Recurring == nil { + logvr(v, r).Tag(tagStripe).Field("stripe_request", fmt.Sprintf("%#v", ev)).Warn("Unexpected request from Stripe") return errHTTPBadRequestBillingRequestInvalid } - subscriptionID, priceID := ev.ID, ev.Items.Data[0].Price.ID + subscriptionID, priceID, interval := ev.ID, ev.Items.Data[0].Price.ID, ev.Items.Data[0].Price.Recurring.Interval logvr(v, r). Tag(tagStripe). Fields(log.Context{ "stripe_webhook_type": event.Type, "stripe_customer_id": ev.Customer, + "stripe_price_id": priceID, "stripe_subscription_id": ev.ID, "stripe_subscription_status": ev.Status, + "stripe_subscription_interval": interval, "stripe_subscription_paid_until": ev.CurrentPeriodEnd, "stripe_subscription_cancel_at": ev.CancelAt, - "stripe_price_id": priceID, }). Info("Updating subscription to status %s, with price %s", ev.Status, priceID) userFn := func() (*user.User, error) { @@ -376,7 +403,7 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(r *http.Request, if err != nil { return err } - if err := s.updateSubscriptionAndTier(r, v, u, tier, ev.Customer, subscriptionID, ev.Status, ev.CurrentPeriodEnd, ev.CancelAt); err != nil { + if err := s.updateSubscriptionAndTier(r, v, u, tier, ev.Customer, subscriptionID, ev.Status, string(interval), ev.CurrentPeriodEnd, ev.CancelAt); err != nil { return err } s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u)) @@ -399,14 +426,14 @@ func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(r *http.Request, Tag(tagStripe). Field("stripe_webhook_type", event.Type). Info("Subscription deleted, downgrading to unpaid tier") - if err := s.updateSubscriptionAndTier(r, v, u, nil, ev.Customer, "", "", 0, 0); err != nil { + if err := s.updateSubscriptionAndTier(r, v, u, nil, ev.Customer, "", "", "", 0, 0); err != nil { return err } s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u)) return nil } -func (s *Server) updateSubscriptionAndTier(r *http.Request, v *visitor, u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error { +func (s *Server) updateSubscriptionAndTier(r *http.Request, v *visitor, u *user.User, tier *user.Tier, customerID, subscriptionID, status, interval string, paidUntil, cancelAt int64) error { reservationsLimit := visitorDefaultReservationsLimit if tier != nil { reservationsLimit = tier.ReservationLimit @@ -423,9 +450,8 @@ func (s *Server) updateSubscriptionAndTier(r *http.Request, v *visitor, u *user. logvr(v, r). Tag(tagStripe). Fields(log.Context{ - "new_tier_id": tier.ID, - "new_tier_name": tier.Name, - "new_tier_stripe_price_id": tier.StripePriceID, + "new_tier_id": tier.ID, + "new_tier_code": tier.Code, }). Info("Changing tier to tier %s (%s) for user %s", tier.ID, tier.Name, u.Name) if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil { @@ -437,6 +463,7 @@ func (s *Server) updateSubscriptionAndTier(r *http.Request, v *visitor, u *user. StripeCustomerID: customerID, StripeSubscriptionID: subscriptionID, StripeSubscriptionStatus: stripe.SubscriptionStatus(status), + StripeSubscriptionInterval: stripe.PriceRecurringInterval(interval), StripeSubscriptionPaidUntil: time.Unix(paidUntil, 0), StripeSubscriptionCancelAt: time.Unix(cancelAt, 0), } @@ -448,20 +475,16 @@ func (s *Server) updateSubscriptionAndTier(r *http.Request, v *visitor, u *user. // fetchStripePrices contacts the Stripe API to retrieve all prices. This is used by the server to cache the prices // in memory, and ultimately for the web app to display the price table. -func (s *Server) fetchStripePrices() (map[string]string, error) { +func (s *Server) fetchStripePrices() (map[string]int64, error) { log.Debug("Caching prices from Stripe API") - priceMap := make(map[string]string) + priceMap := make(map[string]int64) prices, err := s.stripe.ListPrices(&stripe.PriceListParams{Active: stripe.Bool(true)}) if err != nil { log.Warn("Fetching Stripe prices failed: %s", err.Error()) return nil, err } for _, p := range prices { - if p.UnitAmount%100 == 0 { - priceMap[p.ID] = fmt.Sprintf("$%d", p.UnitAmount/100) - } else { - priceMap[p.ID] = fmt.Sprintf("$%.2f", float64(p.UnitAmount)/100) - } + priceMap[p.ID] = p.UnitAmount log.Trace("- Caching price %s = %v", p.ID, priceMap[p.ID]) } return priceMap, nil diff --git a/server/server_payments_test.go b/server/server_payments_test.go index 4640a728..e11af953 100644 --- a/server/server_payments_test.go +++ b/server/server_payments_test.go @@ -37,7 +37,9 @@ func TestPayments_Tiers(t *testing.T) { On("ListPrices", mock.Anything). Return([]*stripe.Price{ {ID: "price_123", UnitAmount: 500}, + {ID: "price_124", UnitAmount: 5000}, {ID: "price_456", UnitAmount: 1000}, + {ID: "price_457", UnitAmount: 10000}, {ID: "price_999", UnitAmount: 9999}, }, nil) @@ -58,7 +60,8 @@ func TestPayments_Tiers(t *testing.T) { AttachmentFileSizeLimit: 999, AttachmentTotalSizeLimit: 888, AttachmentExpiryDuration: time.Minute, - StripePriceID: "price_123", + StripeMonthlyPriceID: "price_123", + StripeYearlyPriceID: "price_124", })) require.Nil(t, s.userManager.AddTier(&user.Tier{ ID: "ti_444", @@ -71,7 +74,8 @@ func TestPayments_Tiers(t *testing.T) { AttachmentFileSizeLimit: 999111, AttachmentTotalSizeLimit: 888111, AttachmentExpiryDuration: time.Hour, - StripePriceID: "price_456", + StripeMonthlyPriceID: "price_456", + StripeYearlyPriceID: "price_457", })) response := request(t, s, "GET", "/v1/tiers", "", nil) require.Equal(t, 200, response.Code) @@ -98,6 +102,8 @@ func TestPayments_Tiers(t *testing.T) { require.Equal(t, "pro", tier.Code) require.Equal(t, "Pro", tier.Name) require.Equal(t, "tier", tier.Limits.Basis) + require.Equal(t, int64(500), tier.Prices.Month) + require.Equal(t, int64(5000), tier.Prices.Year) require.Equal(t, int64(777), tier.Limits.Reservations) require.Equal(t, int64(1000), tier.Limits.Messages) require.Equal(t, int64(3600), tier.Limits.MessagesExpiryDuration) @@ -109,6 +115,8 @@ func TestPayments_Tiers(t *testing.T) { tier = tiers[2] require.Equal(t, "business", tier.Code) require.Equal(t, "Business", tier.Name) + require.Equal(t, int64(1000), tier.Prices.Month) + require.Equal(t, int64(10000), tier.Prices.Year) require.Equal(t, "tier", tier.Limits.Basis) require.Equal(t, int64(777333), tier.Limits.Reservations) require.Equal(t, int64(2000), tier.Limits.Messages) @@ -136,14 +144,14 @@ func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) { // Create tier and user require.Nil(t, s.userManager.AddTier(&user.Tier{ - ID: "ti_123", - Code: "pro", - StripePriceID: "price_123", + ID: "ti_123", + Code: "pro", + StripeMonthlyPriceID: "price_123", })) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) // Create subscription - response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{ + response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro", "interval": "month"}`, map[string]string{ "Authorization": util.BasicAuth("phil", "phil"), }) require.Equal(t, 200, response.Code) @@ -172,9 +180,9 @@ func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) { // Create tier and user require.Nil(t, s.userManager.AddTier(&user.Tier{ - ID: "ti_123", - Code: "pro", - StripePriceID: "price_123", + ID: "ti_123", + Code: "pro", + StripeMonthlyPriceID: "price_123", })) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) @@ -187,7 +195,7 @@ func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) { require.Nil(t, s.userManager.ChangeBilling(u.Name, billing)) // Create subscription - response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{ + response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro", "interval": "month"}`, map[string]string{ "Authorization": util.BasicAuth("phil", "phil"), }) require.Equal(t, 200, response.Code) @@ -214,9 +222,9 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) { // Create tier and user require.Nil(t, s.userManager.AddTier(&user.Tier{ - ID: "ti_123", - Code: "pro", - StripePriceID: "price_123", + ID: "ti_123", + Code: "pro", + StripeMonthlyPriceID: "price_123", })) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) @@ -267,7 +275,7 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes require.Nil(t, s.userManager.AddTier(&user.Tier{ ID: "ti_123", Code: "starter", - StripePriceID: "price_1234", + StripeMonthlyPriceID: "price_1234", ReservationLimit: 1, MessageLimit: 220, // 220 * 5% = 11 requests before rate limiting kicks in MessageExpiryDuration: time.Hour, @@ -298,7 +306,12 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes Items: &stripe.SubscriptionItemList{ Data: []*stripe.SubscriptionItem{ { - Price: &stripe.Price{ID: "price_1234"}, + Price: &stripe.Price{ + ID: "price_1234", + Recurring: &stripe.PriceRecurring{ + Interval: stripe.PriceRecurringIntervalMonth, + }, + }, }, }, }, @@ -333,6 +346,7 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes require.Equal(t, "", u.Billing.StripeCustomerID) require.Equal(t, "", u.Billing.StripeSubscriptionID) require.Equal(t, stripe.SubscriptionStatus(""), u.Billing.StripeSubscriptionStatus) + require.Equal(t, stripe.PriceRecurringInterval(""), u.Billing.StripeSubscriptionInterval) require.Equal(t, int64(0), u.Billing.StripeSubscriptionPaidUntil.Unix()) require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix()) require.Equal(t, int64(0), u.Stats.Messages) // Messages and emails are not persisted for no-tier users! @@ -349,6 +363,7 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes require.Equal(t, "acct_5555", u.Billing.StripeCustomerID) require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID) require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus) + require.Equal(t, stripe.PriceRecurringIntervalMonth, u.Billing.StripeSubscriptionInterval) require.Equal(t, int64(123456789), u.Billing.StripeSubscriptionPaidUntil.Unix()) require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix()) require.Equal(t, int64(0), u.Stats.Messages) @@ -423,7 +438,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active( require.Nil(t, s.userManager.AddTier(&user.Tier{ ID: "ti_1", Code: "starter", - StripePriceID: "price_1234", // ! + StripeMonthlyPriceID: "price_1234", // ! ReservationLimit: 1, // ! MessageLimit: 100, MessageExpiryDuration: time.Hour, @@ -435,7 +450,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active( require.Nil(t, s.userManager.AddTier(&user.Tier{ ID: "ti_2", Code: "pro", - StripePriceID: "price_1111", // ! + StripeMonthlyPriceID: "price_1111", // ! ReservationLimit: 3, // ! MessageLimit: 200, MessageExpiryDuration: time.Hour, @@ -457,6 +472,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active( StripeCustomerID: "acct_5555", StripeSubscriptionID: "sub_1234", StripeSubscriptionStatus: stripe.SubscriptionStatusPastDue, + StripeSubscriptionInterval: stripe.PriceRecurringIntervalMonth, StripeSubscriptionPaidUntil: time.Unix(123, 0), StripeSubscriptionCancelAt: time.Unix(456, 0), } @@ -499,9 +515,10 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active( require.Equal(t, "starter", u.Tier.Code) // Not "pro" require.Equal(t, "acct_5555", u.Billing.StripeCustomerID) require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID) - require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus) // Not "past_due" - require.Equal(t, int64(1674268231), u.Billing.StripeSubscriptionPaidUntil.Unix()) // Updated - require.Equal(t, int64(1674299999), u.Billing.StripeSubscriptionCancelAt.Unix()) // Updated + require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus) // Not "past_due" + require.Equal(t, stripe.PriceRecurringIntervalYear, u.Billing.StripeSubscriptionInterval) // Not "month" + require.Equal(t, int64(1674268231), u.Billing.StripeSubscriptionPaidUntil.Unix()) // Updated + require.Equal(t, int64(1674299999), u.Billing.StripeSubscriptionCancelAt.Unix()) // Updated // Verify that reservations were deleted r, err := s.userManager.Reservations("phil") @@ -546,10 +563,10 @@ func TestPayments_Webhook_Subscription_Deleted(t *testing.T) { // Create a user with a Stripe subscription and 3 reservations require.Nil(t, s.userManager.AddTier(&user.Tier{ - ID: "ti_1", - Code: "pro", - StripePriceID: "price_1234", - ReservationLimit: 1, + ID: "ti_1", + Code: "pro", + StripeMonthlyPriceID: "price_1234", + ReservationLimit: 1, })) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) require.Nil(t, s.userManager.ChangeTier("phil", "pro")) @@ -562,6 +579,7 @@ func TestPayments_Webhook_Subscription_Deleted(t *testing.T) { StripeCustomerID: "acct_5555", StripeSubscriptionID: "sub_1234", StripeSubscriptionStatus: stripe.SubscriptionStatusPastDue, + StripeSubscriptionInterval: stripe.PriceRecurringIntervalMonth, StripeSubscriptionPaidUntil: time.Unix(123, 0), StripeSubscriptionCancelAt: time.Unix(0, 0), })) @@ -615,11 +633,11 @@ func TestPayments_Subscription_Update_Different_Tier(t *testing.T) { stripeMock. On("UpdateSubscription", "sub_123", &stripe.SubscriptionParams{ CancelAtPeriodEnd: stripe.Bool(false), - ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)), + ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorAlwaysInvoice)), Items: []*stripe.SubscriptionItemsParams{ { ID: stripe.String("someid_123"), - Price: stripe.String("price_456"), + Price: stripe.String("price_457"), }, }, }). @@ -627,14 +645,16 @@ func TestPayments_Subscription_Update_Different_Tier(t *testing.T) { // Create tier and user require.Nil(t, s.userManager.AddTier(&user.Tier{ - ID: "ti_123", - Code: "pro", - StripePriceID: "price_123", + ID: "ti_123", + Code: "pro", + StripeMonthlyPriceID: "price_123", + StripeYearlyPriceID: "price_124", })) require.Nil(t, s.userManager.AddTier(&user.Tier{ - ID: "ti_456", - Code: "business", - StripePriceID: "price_456", + ID: "ti_456", + Code: "business", + StripeMonthlyPriceID: "price_456", + StripeYearlyPriceID: "price_457", })) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) require.Nil(t, s.userManager.ChangeTier("phil", "pro")) @@ -644,7 +664,7 @@ func TestPayments_Subscription_Update_Different_Tier(t *testing.T) { })) // Call endpoint to change subscription - rr := request(t, s, "PUT", "/v1/account/billing/subscription", `{"tier":"business"}`, map[string]string{ + rr := request(t, s, "PUT", "/v1/account/billing/subscription", `{"tier":"business","interval":"year"}`, map[string]string{ "Authorization": util.BasicAuth("phil", "phil"), }) require.Equal(t, 200, rr.Code) @@ -795,7 +815,10 @@ const subscriptionUpdatedEventJSON = ` "data": [ { "price": { - "id": "price_1234" + "id": "price_1234", + "recurring": { + "interval": "year" + } } } ] @@ -818,7 +841,10 @@ const subscriptionDeletedEventJSON = ` "data": [ { "price": { - "id": "price_1234" + "id": "price_1234", + "recurring": { + "interval": "month" + } } } ] diff --git a/server/types.go b/server/types.go index c6331359..ead753dd 100644 --- a/server/types.go +++ b/server/types.go @@ -309,6 +309,7 @@ type apiAccountBilling struct { Customer bool `json:"customer"` Subscription bool `json:"subscription"` Status string `json:"status,omitempty"` + Interval string `json:"interval,omitempty"` PaidUntil int64 `json:"paid_until,omitempty"` CancelAt int64 `json:"cancel_at,omitempty"` } @@ -343,11 +344,16 @@ type apiConfigResponse struct { DisallowedTopics []string `json:"disallowed_topics"` } +type apiAccountBillingPrices struct { + Month int64 `json:"month"` + Year int64 `json:"year"` +} + type apiAccountBillingTier struct { - Code string `json:"code,omitempty"` - Name string `json:"name,omitempty"` - Price string `json:"price,omitempty"` - Limits *apiAccountLimits `json:"limits"` + Code string `json:"code,omitempty"` + Name string `json:"name,omitempty"` + Prices *apiAccountBillingPrices `json:"prices,omitempty"` + Limits *apiAccountLimits `json:"limits"` } type apiAccountBillingSubscriptionCreateResponse struct { @@ -355,7 +361,8 @@ type apiAccountBillingSubscriptionCreateResponse struct { } type apiAccountBillingSubscriptionChangeRequest struct { - Tier string `json:"tier"` + Tier string `json:"tier"` + Interval string `json:"interval"` } type apiAccountBillingPortalRedirectResponse struct { @@ -385,7 +392,10 @@ type apiStripeSubscriptionUpdatedEvent struct { Items *struct { Data []*struct { Price *struct { - ID string `json:"id"` + ID string `json:"id"` + Recurring *struct { + Interval string `json:"interval"` + } `json:"recurring"` } `json:"price"` } `json:"data"` } `json:"items"` diff --git a/user/manager.go b/user/manager.go index bb0dc3f3..58a8f4c7 100644 --- a/user/manager.go +++ b/user/manager.go @@ -46,7 +46,8 @@ var ( // Manager-related queries const ( - createTablesQueriesNoTx = ` + createTablesQueries = ` + BEGIN; CREATE TABLE IF NOT EXISTS tier ( id TEXT PRIMARY KEY, code TEXT NOT NULL, @@ -59,10 +60,12 @@ const ( attachment_total_size_limit INT NOT NULL, attachment_expiry_duration INT NOT NULL, attachment_bandwidth_limit INT NOT NULL, - stripe_price_id TEXT + stripe_monthly_price_id TEXT, + stripe_yearly_price_id TEXT ); CREATE UNIQUE INDEX idx_tier_code ON tier (code); - CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id); + CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id); + CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id); CREATE TABLE IF NOT EXISTS user ( id TEXT PRIMARY KEY, tier_id TEXT, @@ -76,6 +79,7 @@ const ( stripe_customer_id TEXT, stripe_subscription_id TEXT, stripe_subscription_status TEXT, + stripe_subscription_interval TEXT, stripe_subscription_paid_until INT, stripe_subscription_cancel_at INT, created INT NOT NULL, @@ -112,33 +116,33 @@ const ( INSERT INTO user (id, user, pass, role, sync_topic, created) VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', UNIXEPOCH()) ON CONFLICT (id) DO NOTHING; + COMMIT; ` - createTablesQueries = `BEGIN; ` + createTablesQueriesNoTx + ` COMMIT;` builtinStartupQueries = ` PRAGMA foreign_keys = ON; ` selectUserByIDQuery = ` - SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id + SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id FROM user u LEFT JOIN tier t on t.id = u.tier_id WHERE u.id = ? ` selectUserByNameQuery = ` - SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id + SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id FROM user u LEFT JOIN tier t on t.id = u.tier_id WHERE user = ? ` selectUserByTokenQuery = ` - SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id + SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id FROM user u JOIN user_token tk on u.id = tk.user_id LEFT JOIN tier t on t.id = u.tier_id WHERE tk.token = ? AND (tk.expires = 0 OR tk.expires >= ?) ` selectUserByStripeCustomerIDQuery = ` - SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id + SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id FROM user u LEFT JOIN tier t on t.id = u.tier_id WHERE u.stripe_customer_id = ? @@ -246,27 +250,27 @@ const ( ` insertTierQuery = ` - INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ` updateTierQuery = ` UPDATE tier - SET name = ?, messages_limit = ?, messages_expiry_duration = ?, emails_limit = ?, reservations_limit = ?, attachment_file_size_limit = ?, attachment_total_size_limit = ?, attachment_expiry_duration = ?, attachment_bandwidth_limit = ?, stripe_price_id = ? + SET name = ?, messages_limit = ?, messages_expiry_duration = ?, emails_limit = ?, reservations_limit = ?, attachment_file_size_limit = ?, attachment_total_size_limit = ?, attachment_expiry_duration = ?, attachment_bandwidth_limit = ?, stripe_monthly_price_id = ?, stripe_yearly_price_id = ? WHERE code = ? ` selectTiersQuery = ` - SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id + SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id FROM tier ` selectTierByCodeQuery = ` - SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id + SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id FROM tier WHERE code = ? ` selectTierByPriceIDQuery = ` - SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id + SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id FROM tier - WHERE stripe_price_id = ? + WHERE (stripe_monthly_price_id = ? OR stripe_yearly_price_id = ?) ` updateUserTierQuery = `UPDATE user SET tier_id = (SELECT id FROM tier WHERE code = ?) WHERE user = ?` deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?` @@ -274,21 +278,86 @@ const ( updateBillingQuery = ` UPDATE user - SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_paid_until = ?, stripe_subscription_cancel_at = ? + SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_interval = ?, stripe_subscription_paid_until = ?, stripe_subscription_cancel_at = ? WHERE user = ? ` ) // Schema management queries const ( - currentSchemaVersion = 2 + currentSchemaVersion = 3 insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)` updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1` selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1` // 1 -> 2 (complex migration!) - migrate1To2RenameUserTableQueryNoTx = ` + migrate1To2CreateTablesQueries = ` ALTER TABLE user RENAME TO user_old; + CREATE TABLE IF NOT EXISTS tier ( + id TEXT PRIMARY KEY, + code TEXT NOT NULL, + name TEXT NOT NULL, + messages_limit INT NOT NULL, + messages_expiry_duration INT NOT NULL, + emails_limit INT NOT NULL, + reservations_limit INT NOT NULL, + attachment_file_size_limit INT NOT NULL, + attachment_total_size_limit INT NOT NULL, + attachment_expiry_duration INT NOT NULL, + attachment_bandwidth_limit INT NOT NULL, + stripe_price_id TEXT + ); + CREATE UNIQUE INDEX idx_tier_code ON tier (code); + CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id); + CREATE TABLE IF NOT EXISTS user ( + id TEXT PRIMARY KEY, + tier_id TEXT, + user TEXT NOT NULL, + pass TEXT NOT NULL, + role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL, + prefs JSON NOT NULL DEFAULT '{}', + sync_topic TEXT NOT NULL, + stats_messages INT NOT NULL DEFAULT (0), + stats_emails INT NOT NULL DEFAULT (0), + stripe_customer_id TEXT, + stripe_subscription_id TEXT, + stripe_subscription_status TEXT, + stripe_subscription_paid_until INT, + stripe_subscription_cancel_at INT, + created INT NOT NULL, + deleted INT, + FOREIGN KEY (tier_id) REFERENCES tier (id) + ); + CREATE UNIQUE INDEX idx_user ON user (user); + CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id); + CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id); + CREATE TABLE IF NOT EXISTS user_access ( + user_id TEXT NOT NULL, + topic TEXT NOT NULL, + read INT NOT NULL, + write INT NOT NULL, + owner_user_id INT, + PRIMARY KEY (user_id, topic), + FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE, + FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE + ); + CREATE TABLE IF NOT EXISTS user_token ( + user_id TEXT NOT NULL, + token TEXT NOT NULL, + label TEXT NOT NULL, + last_access INT NOT NULL, + last_origin TEXT NOT NULL, + expires INT NOT NULL, + PRIMARY KEY (user_id, token), + FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE + ); + CREATE TABLE IF NOT EXISTS schemaVersion ( + id INT PRIMARY KEY, + version INT NOT NULL + ); + INSERT INTO user (id, user, pass, role, sync_topic, created) + VALUES ('u_everyone', '*', '', 'anonymous', '', UNIXEPOCH()) + ON CONFLICT (id) DO NOTHING; ` migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old` migrate1To2InsertUserNoTx = ` @@ -304,11 +373,22 @@ const ( DROP TABLE access; DROP TABLE user_old; ` + + // 2 -> 3 + migrate2To3UpdateQueries = ` + ALTER TABLE user ADD COLUMN stripe_subscription_interval TEXT; + ALTER TABLE tier RENAME COLUMN stripe_price_id TO stripe_monthly_price_id; + ALTER TABLE tier ADD COLUMN stripe_yearly_price_id TEXT; + DROP INDEX IF EXISTS idx_tier_price_id; + CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id); + CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id); + ` ) var ( migrations = map[int]func(db *sql.DB) error{ 1: migrateFrom1, + 2: migrateFrom2, } ) @@ -805,13 +885,13 @@ func (a *Manager) userByToken(token string) (*User, error) { func (a *Manager) readUser(rows *sql.Rows) (*User, error) { defer rows.Close() var id, username, hash, role, prefs, syncTopic string - var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierID, tierCode, tierName sql.NullString + var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripeSubscriptionInterval, stripeMonthlyPriceID, stripeYearlyPriceID, tierID, tierCode, tierName sql.NullString var messages, emails int64 var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64 if !rows.Next() { return nil, ErrUserNotFound } - if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil { + if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil { return nil, err } else if err := rows.Err(); err != nil { return nil, err @@ -828,11 +908,12 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) { Emails: emails, }, Billing: &Billing{ - StripeCustomerID: stripeCustomerID.String, // May be empty - StripeSubscriptionID: stripeSubscriptionID.String, // May be empty - StripeSubscriptionStatus: stripe.SubscriptionStatus(stripeSubscriptionStatus.String), // May be empty - StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), // May be zero - StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0), // May be zero + StripeCustomerID: stripeCustomerID.String, // May be empty + StripeSubscriptionID: stripeSubscriptionID.String, // May be empty + StripeSubscriptionStatus: stripe.SubscriptionStatus(stripeSubscriptionStatus.String), // May be empty + StripeSubscriptionInterval: stripe.PriceRecurringInterval(stripeSubscriptionInterval.String), // May be empty + StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), // May be zero + StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0), // May be zero }, Deleted: deleted.Valid, } @@ -853,7 +934,8 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) { AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64, AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second, AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64, - StripePriceID: stripePriceID.String, // May be empty + StripeMonthlyPriceID: stripeMonthlyPriceID.String, // May be empty + StripeYearlyPriceID: stripeYearlyPriceID.String, // May be empty } } return user, nil @@ -1134,7 +1216,7 @@ func (a *Manager) AddTier(tier *Tier) error { if tier.ID == "" { tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength) } - if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripePriceID)); err != nil { + if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID)); err != nil { return err } return nil @@ -1142,7 +1224,7 @@ func (a *Manager) AddTier(tier *Tier) error { // UpdateTier updates a tier's properties in the database func (a *Manager) UpdateTier(tier *Tier) error { - if _, err := a.db.Exec(updateTierQuery, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripePriceID), tier.Code); err != nil { + if _, err := a.db.Exec(updateTierQuery, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID), tier.Code); err != nil { return err } return nil @@ -1162,7 +1244,7 @@ func (a *Manager) RemoveTier(code string) error { // ChangeBilling updates a user's billing fields, namely the Stripe customer ID, and subscription information func (a *Manager) ChangeBilling(username string, billing *Billing) error { - if _, err := a.db.Exec(updateBillingQuery, nullString(billing.StripeCustomerID), nullString(billing.StripeSubscriptionID), nullString(string(billing.StripeSubscriptionStatus)), nullInt64(billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(billing.StripeSubscriptionCancelAt.Unix()), username); err != nil { + if _, err := a.db.Exec(updateBillingQuery, nullString(billing.StripeCustomerID), nullString(billing.StripeSubscriptionID), nullString(string(billing.StripeSubscriptionStatus)), nullString(string(billing.StripeSubscriptionInterval)), nullInt64(billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(billing.StripeSubscriptionCancelAt.Unix()), username); err != nil { return err } return nil @@ -1200,7 +1282,7 @@ func (a *Manager) Tier(code string) (*Tier, error) { // TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) { - rows, err := a.db.Query(selectTierByPriceIDQuery, priceID) + rows, err := a.db.Query(selectTierByPriceIDQuery, priceID, priceID) if err != nil { return nil, err } @@ -1210,12 +1292,12 @@ func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) { func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) { var id, code, name string - var stripePriceID sql.NullString + var stripeMonthlyPriceID, stripeYearlyPriceID sql.NullString var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit sql.NullInt64 if !rows.Next() { return nil, ErrTierNotFound } - if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil { + if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil { return nil, err } else if err := rows.Err(); err != nil { return nil, err @@ -1233,7 +1315,8 @@ func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) { AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64, AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second, AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64, - StripePriceID: stripePriceID.String, // May be empty + StripeMonthlyPriceID: stripeMonthlyPriceID.String, // May be empty + StripeYearlyPriceID: stripeYearlyPriceID.String, // May be empty }, nil } @@ -1313,10 +1396,7 @@ func migrateFrom1(db *sql.DB) error { } defer tx.Rollback() // Rename user -> user_old, and create new tables - if _, err := tx.Exec(migrate1To2RenameUserTableQueryNoTx); err != nil { - return err - } - if _, err := tx.Exec(createTablesQueriesNoTx); err != nil { + if _, err := tx.Exec(migrate1To2CreateTablesQueries); err != nil { return err } // Insert users from user_old into new user table, with ID and sync_topic @@ -1356,6 +1436,22 @@ func migrateFrom1(db *sql.DB) error { return nil } +func migrateFrom2(db *sql.DB) error { + log.Tag(tag).Info("Migrating user database schema: from 2 to 3") + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + if _, err := tx.Exec(migrate2To3UpdateQueries); err != nil { + return err + } + if _, err := tx.Exec(updateSchemaVersion, 3); err != nil { + return err + } + return tx.Commit() +} + func nullString(s string) sql.NullString { if s == "" { return sql.NullString{} diff --git a/user/manager_test.go b/user/manager_test.go index f809b5a9..f242af71 100644 --- a/user/manager_test.go +++ b/user/manager_test.go @@ -4,6 +4,7 @@ import ( "database/sql" "fmt" "github.com/stretchr/testify/require" + "github.com/stripe/stripe-go/v74" "golang.org/x/crypto/bcrypt" "heckel.io/ntfy/util" "net/netip" @@ -113,7 +114,8 @@ func TestManager_AddUser_And_Query(t *testing.T) { require.Nil(t, a.ChangeBilling("user", &Billing{ StripeCustomerID: "acct_123", StripeSubscriptionID: "sub_123", - StripeSubscriptionStatus: "active", + StripeSubscriptionStatus: stripe.SubscriptionStatusActive, + StripeSubscriptionInterval: stripe.PriceRecurringIntervalMonth, StripeSubscriptionPaidUntil: time.Now().Add(time.Hour), StripeSubscriptionCancelAt: time.Unix(0, 0), })) @@ -395,7 +397,7 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) { require.Nil(t, a.AddTier(&Tier{ Code: "pro", Name: "ntfy Pro", - StripePriceID: "price123", + StripeMonthlyPriceID: "price123", MessageLimit: 5_000, MessageExpiryDuration: 3 * 24 * time.Hour, EmailLimit: 50, @@ -761,7 +763,7 @@ func TestManager_Tier_Create_Update_List_Delete(t *testing.T) { AttachmentTotalSizeLimit: 1, AttachmentExpiryDuration: time.Second, AttachmentBandwidthLimit: 1, - StripePriceID: "price_1", + StripeMonthlyPriceID: "price_1", })) require.Nil(t, a.AddTier(&Tier{ Code: "pro", @@ -774,7 +776,7 @@ func TestManager_Tier_Create_Update_List_Delete(t *testing.T) { AttachmentTotalSizeLimit: 123123, AttachmentExpiryDuration: 10800 * time.Second, AttachmentBandwidthLimit: 21474836480, - StripePriceID: "price_2", + StripeMonthlyPriceID: "price_2", })) require.Nil(t, a.AddUser("phil", "phil", RoleUser)) require.Nil(t, a.ChangeTier("phil", "pro")) @@ -800,7 +802,7 @@ func TestManager_Tier_Create_Update_List_Delete(t *testing.T) { require.Equal(t, int64(123123), ti.AttachmentTotalSizeLimit) require.Equal(t, 10800*time.Second, ti.AttachmentExpiryDuration) require.Equal(t, int64(21474836480), ti.AttachmentBandwidthLimit) - require.Equal(t, "price_2", ti.StripePriceID) + require.Equal(t, "price_2", ti.StripeMonthlyPriceID) // Update tier ti.EmailLimit = 999999 @@ -822,7 +824,7 @@ func TestManager_Tier_Create_Update_List_Delete(t *testing.T) { require.Equal(t, int64(1), ti.AttachmentTotalSizeLimit) require.Equal(t, time.Second, ti.AttachmentExpiryDuration) require.Equal(t, int64(1), ti.AttachmentBandwidthLimit) - require.Equal(t, "price_1", ti.StripePriceID) + require.Equal(t, "price_1", ti.StripeMonthlyPriceID) ti = tiers[1] require.Equal(t, "pro", ti.Code) @@ -835,7 +837,7 @@ func TestManager_Tier_Create_Update_List_Delete(t *testing.T) { require.Equal(t, int64(123123), ti.AttachmentTotalSizeLimit) require.Equal(t, 10800*time.Second, ti.AttachmentExpiryDuration) require.Equal(t, int64(21474836480), ti.AttachmentBandwidthLimit) - require.Equal(t, "price_2", ti.StripePriceID) + require.Equal(t, "price_2", ti.StripeMonthlyPriceID) ti, err = a.TierByStripePrice("price_1") require.Nil(t, err) @@ -849,7 +851,7 @@ func TestManager_Tier_Create_Update_List_Delete(t *testing.T) { require.Equal(t, int64(1), ti.AttachmentTotalSizeLimit) require.Equal(t, time.Second, ti.AttachmentExpiryDuration) require.Equal(t, int64(1), ti.AttachmentBandwidthLimit) - require.Equal(t, "price_1", ti.StripePriceID) + require.Equal(t, "price_1", ti.StripeMonthlyPriceID) // Cannot remove tier, since user has this tier require.Error(t, a.RemoveTier("pro")) diff --git a/user/types.go b/user/types.go index 0363c97f..2486f110 100644 --- a/user/types.go +++ b/user/types.go @@ -91,15 +91,17 @@ type Tier struct { AttachmentTotalSizeLimit int64 // Total file size for all files of this user (bytes) AttachmentExpiryDuration time.Duration // Duration after which attachments will be deleted AttachmentBandwidthLimit int64 // Daily bandwidth limit for the user - StripePriceID string // Price ID for paid tiers (price_...) + StripeMonthlyPriceID string // Monthly price ID for paid tiers (price_...) + StripeYearlyPriceID string // Yearly price ID for paid tiers (price_...) } // Context returns fields for the log func (t *Tier) Context() log.Context { return log.Context{ - "tier_id": t.ID, - "tier_code": t.Code, - "stripe_price_id": t.StripePriceID, + "tier_id": t.ID, + "tier_code": t.Code, + "stripe_monthly_price_id": t.StripeMonthlyPriceID, + "stripe_yearly_price_id": t.StripeYearlyPriceID, } } @@ -136,6 +138,7 @@ type Billing struct { StripeCustomerID string StripeSubscriptionID string StripeSubscriptionStatus stripe.SubscriptionStatus + StripeSubscriptionInterval stripe.PriceRecurringInterval StripeSubscriptionPaidUntil time.Time StripeSubscriptionCancelAt time.Time } diff --git a/user/types_test.go b/user/types_test.go index 22dd6c7b..811d33f2 100644 --- a/user/types_test.go +++ b/user/types_test.go @@ -49,12 +49,15 @@ func TestAllowedTier(t *testing.T) { func TestTierContext(t *testing.T) { tier := &Tier{ - ID: "ti_abc", - Code: "pro", - StripePriceID: "price_123", + ID: "ti_abc", + Code: "pro", + StripeMonthlyPriceID: "price_123", + StripeYearlyPriceID: "price_456", } context := tier.Context() require.Equal(t, "ti_abc", context["tier_id"]) require.Equal(t, "pro", context["tier_code"]) - require.Equal(t, "price_123", context["stripe_price_id"]) + require.Equal(t, "price_123", context["stripe_monthly_price_id"]) + require.Equal(t, "price_456", context["stripe_yearly_price_id"]) + } diff --git a/web/public/static/langs/en.json b/web/public/static/langs/en.json index 04f98e40..efd877d3 100644 --- a/web/public/static/langs/en.json +++ b/web/public/static/langs/en.json @@ -193,6 +193,8 @@ "account_basics_tier_admin_suffix_no_tier": "(no tier)", "account_basics_tier_basic": "Basic", "account_basics_tier_free": "Free", + "account_basics_tier_interval_monthly": "monthly", + "account_basics_tier_interval_yearly": "annually", "account_basics_tier_upgrade_button": "Upgrade to Pro", "account_basics_tier_change_button": "Change", "account_basics_tier_paid_until": "Subscription paid until {{date}}, and will auto-renew", @@ -215,15 +217,21 @@ "account_delete_dialog_button_submit": "Permanently delete account", "account_delete_dialog_billing_warning": "Deleting your account also cancels your billing subscription immediately. You will not have access to the billing dashboard anymore.", "account_upgrade_dialog_title": "Change account tier", + "account_upgrade_dialog_interval_monthly": "Monthly", + "account_upgrade_dialog_interval_yearly": "Annually", "account_upgrade_dialog_cancel_warning": "This will cancel your subscription, and downgrade your account on {{date}}. On that date, topic reservations as well as messages cached on the server will be deleted.", - "account_upgrade_dialog_proration_info": "Proration: When switching between paid plans, the price difference will be charged or refunded in the next invoice. You will not receive another invoice until the end of the next billing period.", + "account_upgrade_dialog_proration_info": "Proration: When upgrading between paid plans, the price difference will be charged immediately. When downgrading to a lower tier, the balance will be used to pay for future billing periods.", "account_upgrade_dialog_reservations_warning_one": "The selected tier allows fewer reserved topics than your current tier. Before changing your tier, please delete at least one reservation. You can remove reservations in the Settings.", "account_upgrade_dialog_reservations_warning_other": "The selected tier allows fewer reserved topics than your current tier. Before changing your tier, please delete at least {{count}} reservations. You can remove reservations in the Settings.", "account_upgrade_dialog_tier_features_reservations": "{{reservations}} reserved topics", + "account_upgrade_dialog_tier_features_no_reservations": "No reserved topics", "account_upgrade_dialog_tier_features_messages": "{{messages}} daily messages", "account_upgrade_dialog_tier_features_emails": "{{emails}} daily emails", "account_upgrade_dialog_tier_features_attachment_file_size": "{{filesize}} per file", "account_upgrade_dialog_tier_features_attachment_total_size": "{{totalsize}} total storage", + "account_upgrade_dialog_tier_price_per_month": "month", + "account_upgrade_dialog_tier_price_billed_monthly": "{{price}} per year. Billed monthly.", + "account_upgrade_dialog_tier_price_billed_yearly": "{{price}} billed annually. Save {{save}}.", "account_upgrade_dialog_tier_selected_label": "Selected", "account_upgrade_dialog_tier_current_label": "Current", "account_upgrade_dialog_button_cancel": "Cancel", diff --git a/web/src/app/AccountApi.js b/web/src/app/AccountApi.js index 6382d1fa..243286b4 100644 --- a/web/src/app/AccountApi.js +++ b/web/src/app/AccountApi.js @@ -257,23 +257,24 @@ class AccountApi { return this.tiers; } - async createBillingSubscription(tier) { - console.log(`[AccountApi] Creating billing subscription with ${tier}`); - return await this.upsertBillingSubscription("POST", tier) + async createBillingSubscription(tier, interval) { + console.log(`[AccountApi] Creating billing subscription with ${tier} and interval ${interval}`); + return await this.upsertBillingSubscription("POST", tier, interval) } - async updateBillingSubscription(tier) { - console.log(`[AccountApi] Updating billing subscription with ${tier}`); - return await this.upsertBillingSubscription("PUT", tier) + async updateBillingSubscription(tier, interval) { + console.log(`[AccountApi] Updating billing subscription with ${tier} and interval ${interval}`); + return await this.upsertBillingSubscription("PUT", tier, interval) } - async upsertBillingSubscription(method, tier) { + async upsertBillingSubscription(method, tier, interval) { const url = accountBillingSubscriptionUrl(config.base_url); const response = await fetchOrThrow(url, { method: method, headers: withBearerAuth({}, session.token()), body: JSON.stringify({ - tier: tier + tier: tier, + interval: interval }) }); return await response.json(); // May throw SyntaxError @@ -371,6 +372,12 @@ export const SubscriptionStatus = { PAST_DUE: "past_due" }; +// Maps to stripe.PriceRecurringInterval +export const SubscriptionInterval = { + MONTH: "month", + YEAR: "year" +}; + // Maps to user.Permission in user/types.go export const Permission = { READ_WRITE: "read-write", diff --git a/web/src/app/utils.js b/web/src/app/utils.js index 88f67ce4..6eb4ac54 100644 --- a/web/src/app/utils.js +++ b/web/src/app/utils.js @@ -212,6 +212,13 @@ export const formatNumber = (n) => { return n; } +export const formatPrice = (n) => { + if (n % 100 === 0) { + return `$${n/100}`; + } + return `$${(n/100).toPrecision(2)}`; +} + export const openUrl = (url) => { window.open(url, "_blank", "noopener,noreferrer"); }; diff --git a/web/src/components/Account.js b/web/src/components/Account.js index 224999b6..e5b60077 100644 --- a/web/src/components/Account.js +++ b/web/src/components/Account.js @@ -35,7 +35,7 @@ import TextField from "@mui/material/TextField"; import routes from "./routes"; import IconButton from "@mui/material/IconButton"; import {formatBytes, formatShortDate, formatShortDateTime, openUrl} from "../app/utils"; -import accountApi, {LimitBasis, Role, SubscriptionStatus} from "../app/AccountApi"; +import accountApi, {LimitBasis, Role, SubscriptionInterval, SubscriptionStatus} from "../app/AccountApi"; import InfoOutlinedIcon from '@mui/icons-material/InfoOutlined'; import {Pref, PrefGroup} from "./Pref"; import db from "../app/db"; @@ -248,6 +248,11 @@ const AccountType = () => { accountType = (config.enable_payments) ? t("account_basics_tier_free") : t("account_basics_tier_basic"); } else { accountType = account.tier.name; + if (account.billing?.interval === SubscriptionInterval.MONTH) { + accountType += ` (${t("account_basics_tier_interval_monthly")})`; + } else if (account.billing?.interval === SubscriptionInterval.YEAR) { + accountType += ` (${t("account_basics_tier_interval_yearly")})`; + } } return ( diff --git a/web/src/components/UpgradeDialog.js b/web/src/components/UpgradeDialog.js index 247131c3..1ec07a25 100644 --- a/web/src/components/UpgradeDialog.js +++ b/web/src/components/UpgradeDialog.js @@ -3,20 +3,20 @@ import {useContext, useEffect, useState} from 'react'; import Dialog from '@mui/material/Dialog'; import DialogContent from '@mui/material/DialogContent'; import DialogTitle from '@mui/material/DialogTitle'; -import {Alert, CardActionArea, CardContent, ListItem, useMediaQuery} from "@mui/material"; +import {Alert, Badge, CardActionArea, CardContent, Chip, ListItem, Stack, Switch, useMediaQuery} from "@mui/material"; import theme from "./theme"; import DialogFooter from "./DialogFooter"; import Button from "@mui/material/Button"; -import accountApi from "../app/AccountApi"; +import accountApi, {SubscriptionInterval} from "../app/AccountApi"; import session from "../app/Session"; import routes from "./routes"; import Card from "@mui/material/Card"; import Typography from "@mui/material/Typography"; import {AccountContext} from "./App"; -import {formatBytes, formatNumber, formatShortDate} from "../app/utils"; +import {formatBytes, formatNumber, formatPrice, formatShortDate} from "../app/utils"; import {Trans, useTranslation} from "react-i18next"; import List from "@mui/material/List"; -import {Check} from "@mui/icons-material"; +import {Check, Close} from "@mui/icons-material"; import ListItemIcon from "@mui/material/ListItemIcon"; import ListItemText from "@mui/material/ListItemText"; import Box from "@mui/material/Box"; @@ -28,6 +28,7 @@ const UpgradeDialog = (props) => { const { account } = useContext(AccountContext); // May be undefined! const [error, setError] = useState(""); const [tiers, setTiers] = useState(null); + const [interval, setInterval] = useState(account?.billing?.interval || SubscriptionInterval.YEAR); const [newTierCode, setNewTierCode] = useState(account?.tier?.code); // May be undefined const [loading, setLoading] = useState(false); const fullScreen = useMediaQuery(theme.breakpoints.down('sm')); @@ -46,6 +47,7 @@ const UpgradeDialog = (props) => { const tiersMap = Object.assign(...tiers.map(tier => ({[tier.code]: tier}))); const newTier = tiersMap[newTierCode]; // May be undefined const currentTier = account?.tier; // May be undefined + const currentInterval = account?.billing?.interval; // May be undefined const currentTierCode = currentTier?.code; // May be undefined // Figure out buttons, labels and the submit action @@ -54,7 +56,7 @@ const UpgradeDialog = (props) => { submitButtonLabel = t("account_upgrade_dialog_button_redirect_signup"); submitAction = Action.REDIRECT_SIGNUP; banner = null; - } else if (currentTierCode === newTierCode) { + } else if (currentTierCode === newTierCode && currentInterval === interval) { submitButtonLabel = t("account_upgrade_dialog_button_update_subscription"); submitAction = null; banner = (currentTierCode) ? Banner.PRORATION_INFO : null; @@ -88,10 +90,10 @@ const UpgradeDialog = (props) => { try { setLoading(true); if (submitAction === Action.CREATE_SUBSCRIPTION) { - const response = await accountApi.createBillingSubscription(newTierCode); + const response = await accountApi.createBillingSubscription(newTierCode, interval); window.location.href = response.redirect_url; } else if (submitAction === Action.UPDATE_SUBSCRIPTION) { - await accountApi.updateBillingSubscription(newTierCode); + await accountApi.updateBillingSubscription(newTierCode, interval); } else if (submitAction === Action.CANCEL_SUBSCRIPTION) { await accountApi.deleteBillingSubscription(); } @@ -108,15 +110,45 @@ const UpgradeDialog = (props) => { } } + // Figure out discount + let discount; + if (newTier?.prices) { + discount = Math.round(((newTier.prices.month*12/newTier.prices.year)-1)*100); + } else { + for (const t of tiers) { + if (t.prices) { + discount = Math.round(((t.prices.month*12/t.prices.year)-1)*100); + break; + } + } + } + return ( - {t("account_upgrade_dialog_title")} + +
+
{t("account_upgrade_dialog_title")}
+
+ {t("account_upgrade_dialog_interval_monthly")} + setInterval(ev.target.checked ? SubscriptionInterval.YEAR : SubscriptionInterval.MONTH)} + /> + {t("account_upgrade_dialog_interval_yearly")} + {discount > 0 && } +
+
+
{ tier={tier} current={currentTierCode === tier.code} // tier.code or currentTierCode may be undefined! selected={newTierCode === tier.code} // tier.code may be undefined! + interval={interval} onClick={() => setNewTierCode(tier.code)} // tier.code may be undefined! /> )}
{banner === Banner.CANCEL_WARNING && - + } {banner === Banner.PRORATION_INFO && - + } {banner === Banner.RESERVATIONS_WARNING && - + { const TierCard = (props) => { const { t } = useTranslation(); const tier = props.tier; + let cardStyle, labelStyle, labelText; if (props.selected) { - cardStyle = { background: "#eee", border: "2px solid #338574" }; + cardStyle = { background: "#eee", border: "3px solid #338574" }; labelStyle = { background: "#338574", color: "white" }; labelText = t("account_upgrade_dialog_tier_selected_label"); } else if (props.current) { - cardStyle = { border: "2px solid #eee" }; + cardStyle = { border: "3px solid #eee" }; labelStyle = { background: "#eee", color: "black" }; labelText = t("account_upgrade_dialog_tier_current_label"); } else { - cardStyle = { border: "2px solid transparent" }; + cardStyle = { border: "3px solid transparent" }; + } + + let monthlyPrice; + if (!tier.prices) { + monthlyPrice = 0; + } else if (props.interval === SubscriptionInterval.YEAR) { + monthlyPrice = tier.prices.year/12; + } else if (props.interval === SubscriptionInterval.MONTH) { + monthlyPrice = tier.prices.month; } return ( { ...labelStyle }}>{labelText} } - + {tier.name || t("account_basics_tier_free")} +
+ {formatPrice(monthlyPrice)} + {monthlyPrice > 0 && <>/ {t("account_upgrade_dialog_tier_price_per_month")}} +
- {tier.limits.reservations > 0 && {t("account_upgrade_dialog_tier_features_reservations", { reservations: tier.limits.reservations })}} - {t("account_upgrade_dialog_tier_features_messages", { messages: formatNumber(tier.limits.messages) })} - {t("account_upgrade_dialog_tier_features_emails", { emails: formatNumber(tier.limits.emails) })} - {t("account_upgrade_dialog_tier_features_attachment_file_size", { filesize: formatBytes(tier.limits.attachment_file_size, 0) })} - {t("account_upgrade_dialog_tier_features_attachment_total_size", { totalsize: formatBytes(tier.limits.attachment_total_size, 0) })} + {tier.limits.reservations > 0 && {t("account_upgrade_dialog_tier_features_reservations", { reservations: tier.limits.reservations })}} + {tier.limits.reservations === 0 && {t("account_upgrade_dialog_tier_features_no_reservations")}} + {t("account_upgrade_dialog_tier_features_messages", { messages: formatNumber(tier.limits.messages) })} + {t("account_upgrade_dialog_tier_features_emails", { emails: formatNumber(tier.limits.emails) })} + {t("account_upgrade_dialog_tier_features_attachment_file_size", { filesize: formatBytes(tier.limits.attachment_file_size, 0) })} + {t("account_upgrade_dialog_tier_features_attachment_total_size", { totalsize: formatBytes(tier.limits.attachment_total_size, 0) })} - {tier.price && - - {tier.price} / month + {tier.prices && props.interval === SubscriptionInterval.MONTH && + + {t("account_upgrade_dialog_tier_price_billed_monthly", { price: formatPrice(tier.prices.month*12) })} + + } + {tier.prices && props.interval === SubscriptionInterval.YEAR && + + {t("account_upgrade_dialog_tier_price_billed_yearly", { price: formatPrice(tier.prices.year), save: formatPrice(tier.prices.month*12-tier.prices.year) })} } @@ -231,16 +283,25 @@ const TierCard = (props) => { ); } +const Feature = (props) => { + return {props.children}; +} + +const NoFeature = (props) => { + return {props.children}; +} + const FeatureItem = (props) => { return ( - + {props.feature && } + {!props.feature && } + {props.children} }