1
0
Fork 0
mirror of https://github.com/binwiederhier/ntfy.git synced 2025-05-31 18:49:20 +02:00

Support for annual billing intervals

This commit is contained in:
binwiederhier 2023-02-21 22:44:30 -05:00
parent 07afaf961d
commit ef9d6d9f6c
17 changed files with 453 additions and 183 deletions

View file

@ -80,14 +80,17 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
return err
}
for _, tier := range tiers {
priceStr, ok := prices[tier.StripePriceID]
if tier.StripePriceID == "" || !ok {
priceMonth, priceYear := prices[tier.StripeMonthlyPriceID], prices[tier.StripeYearlyPriceID]
if priceMonth == 0 || priceYear == 0 { // Only allow tiers that have both prices!
continue
}
response = append(response, &apiAccountBillingTier{
Code: tier.Code,
Name: tier.Name,
Price: priceStr,
Code: tier.Code,
Name: tier.Name,
Prices: &apiAccountBillingPrices{
Month: priceMonth,
Year: priceYear,
},
Limits: &apiAccountLimits{
Basis: string(visitorLimitBasisTier),
Messages: tier.MessageLimit,
@ -117,11 +120,21 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
tier, err := s.userManager.Tier(req.Tier)
if err != nil {
return err
} else if tier.StripePriceID == "" {
}
var priceID string
if req.Interval == string(stripe.PriceRecurringIntervalMonth) && tier.StripeMonthlyPriceID != "" {
priceID = tier.StripeMonthlyPriceID
} else if req.Interval == string(stripe.PriceRecurringIntervalYear) && tier.StripeYearlyPriceID != "" {
priceID = tier.StripeYearlyPriceID
} else {
return errNotAPaidTier
}
logvr(v, r).
With(tier).
Fields(log.Context{
"stripe_price_id": priceID,
"stripe_subscription_interval": req.Interval,
}).
Tag(tagStripe).
Info("Creating Stripe checkout flow")
var stripeCustomerID *string
@ -143,7 +156,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
AllowPromotionCodes: stripe.Bool(true),
LineItems: []*stripe.CheckoutSessionLineItemParams{
{
Price: stripe.String(tier.StripePriceID),
Price: stripe.String(priceID),
Quantity: stripe.Int64(1),
},
},
@ -180,10 +193,11 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
sub, err := s.stripe.GetSubscription(sess.Subscription.ID)
if err != nil {
return err
} else if sub.Items == nil || len(sub.Items.Data) != 1 || sub.Items.Data[0].Price == nil {
} else if sub.Items == nil || len(sub.Items.Data) != 1 || sub.Items.Data[0].Price == nil || sub.Items.Data[0].Price.Recurring == nil {
return wrapErrHTTP(errHTTPBadRequestBillingRequestInvalid, "more than one line item in existing subscription")
}
tier, err := s.userManager.TierByStripePrice(sub.Items.Data[0].Price.ID)
priceID, interval := sub.Items.Data[0].Price.ID, sub.Items.Data[0].Price.Recurring.Interval
tier, err := s.userManager.TierByStripePrice(priceID)
if err != nil {
return err
}
@ -197,8 +211,10 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
Tag(tagStripe).
Fields(log.Context{
"stripe_customer_id": sess.Customer.ID,
"stripe_price_id": priceID,
"stripe_subscription_id": sub.ID,
"stripe_subscription_status": string(sub.Status),
"stripe_subscription_interval": string(interval),
"stripe_subscription_paid_until": sub.CurrentPeriodEnd,
}).
Info("Stripe checkout flow succeeded, updating user tier and subscription")
@ -213,7 +229,7 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
if _, err := s.stripe.UpdateCustomer(sess.Customer.ID, customerParams); err != nil {
return err
}
if err := s.updateSubscriptionAndTier(r, v, u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil {
if err := s.updateSubscriptionAndTier(r, v, u, tier, sess.Customer.ID, sub.ID, string(sub.Status), string(interval), sub.CurrentPeriodEnd, sub.CancelAt); err != nil {
return err
}
http.Redirect(w, r, s.config.BaseURL+accountPath, http.StatusSeeOther)
@ -235,15 +251,24 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
if err != nil {
return err
}
var priceID string
if req.Interval == string(stripe.PriceRecurringIntervalMonth) && tier.StripeMonthlyPriceID != "" {
priceID = tier.StripeMonthlyPriceID
} else if req.Interval == string(stripe.PriceRecurringIntervalYear) && tier.StripeYearlyPriceID != "" {
priceID = tier.StripeYearlyPriceID
} else {
return errNotAPaidTier
}
logvr(v, r).
Tag(tagStripe).
Fields(log.Context{
"new_tier_id": tier.ID,
"new_tier_name": tier.Name,
"new_tier_stripe_price_id": tier.StripePriceID,
"new_tier_id": tier.ID,
"new_tier_code": tier.Code,
"new_tier_stripe_price_id": priceID,
"new_tier_stripe_subscription_interval": req.Interval,
// Other stripe_* fields filled by visitor context
}).
Info("Changing Stripe subscription and billing tier to %s/%s (price %s)", tier.ID, tier.Name, tier.StripePriceID)
Info("Changing Stripe subscription and billing tier to %s/%s (price %s, %s)", tier.ID, tier.Name, priceID, req.Interval)
sub, err := s.stripe.GetSubscription(u.Billing.StripeSubscriptionID)
if err != nil {
return err
@ -252,11 +277,11 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
}
params := &stripe.SubscriptionParams{
CancelAtPeriodEnd: stripe.Bool(false),
ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)),
ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorAlwaysInvoice)),
Items: []*stripe.SubscriptionItemsParams{
{
ID: stripe.String(sub.Items.Data[0].ID),
Price: stripe.String(tier.StripePriceID),
Price: stripe.String(priceID),
},
},
}
@ -345,20 +370,22 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(r *http.Request,
ev, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event.Data.Raw)))
if err != nil {
return err
} else if ev.ID == "" || ev.Customer == "" || ev.Status == "" || ev.CurrentPeriodEnd == 0 || ev.Items == nil || len(ev.Items.Data) != 1 || ev.Items.Data[0].Price == nil || ev.Items.Data[0].Price.ID == "" {
} else if ev.ID == "" || ev.Customer == "" || ev.Status == "" || ev.CurrentPeriodEnd == 0 || ev.Items == nil || len(ev.Items.Data) != 1 || ev.Items.Data[0].Price == nil || ev.Items.Data[0].Price.ID == "" || ev.Items.Data[0].Price.Recurring == nil {
logvr(v, r).Tag(tagStripe).Field("stripe_request", fmt.Sprintf("%#v", ev)).Warn("Unexpected request from Stripe")
return errHTTPBadRequestBillingRequestInvalid
}
subscriptionID, priceID := ev.ID, ev.Items.Data[0].Price.ID
subscriptionID, priceID, interval := ev.ID, ev.Items.Data[0].Price.ID, ev.Items.Data[0].Price.Recurring.Interval
logvr(v, r).
Tag(tagStripe).
Fields(log.Context{
"stripe_webhook_type": event.Type,
"stripe_customer_id": ev.Customer,
"stripe_price_id": priceID,
"stripe_subscription_id": ev.ID,
"stripe_subscription_status": ev.Status,
"stripe_subscription_interval": interval,
"stripe_subscription_paid_until": ev.CurrentPeriodEnd,
"stripe_subscription_cancel_at": ev.CancelAt,
"stripe_price_id": priceID,
}).
Info("Updating subscription to status %s, with price %s", ev.Status, priceID)
userFn := func() (*user.User, error) {
@ -376,7 +403,7 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(r *http.Request,
if err != nil {
return err
}
if err := s.updateSubscriptionAndTier(r, v, u, tier, ev.Customer, subscriptionID, ev.Status, ev.CurrentPeriodEnd, ev.CancelAt); err != nil {
if err := s.updateSubscriptionAndTier(r, v, u, tier, ev.Customer, subscriptionID, ev.Status, string(interval), ev.CurrentPeriodEnd, ev.CancelAt); err != nil {
return err
}
s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u))
@ -399,14 +426,14 @@ func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(r *http.Request,
Tag(tagStripe).
Field("stripe_webhook_type", event.Type).
Info("Subscription deleted, downgrading to unpaid tier")
if err := s.updateSubscriptionAndTier(r, v, u, nil, ev.Customer, "", "", 0, 0); err != nil {
if err := s.updateSubscriptionAndTier(r, v, u, nil, ev.Customer, "", "", "", 0, 0); err != nil {
return err
}
s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u))
return nil
}
func (s *Server) updateSubscriptionAndTier(r *http.Request, v *visitor, u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error {
func (s *Server) updateSubscriptionAndTier(r *http.Request, v *visitor, u *user.User, tier *user.Tier, customerID, subscriptionID, status, interval string, paidUntil, cancelAt int64) error {
reservationsLimit := visitorDefaultReservationsLimit
if tier != nil {
reservationsLimit = tier.ReservationLimit
@ -423,9 +450,8 @@ func (s *Server) updateSubscriptionAndTier(r *http.Request, v *visitor, u *user.
logvr(v, r).
Tag(tagStripe).
Fields(log.Context{
"new_tier_id": tier.ID,
"new_tier_name": tier.Name,
"new_tier_stripe_price_id": tier.StripePriceID,
"new_tier_id": tier.ID,
"new_tier_code": tier.Code,
}).
Info("Changing tier to tier %s (%s) for user %s", tier.ID, tier.Name, u.Name)
if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil {
@ -437,6 +463,7 @@ func (s *Server) updateSubscriptionAndTier(r *http.Request, v *visitor, u *user.
StripeCustomerID: customerID,
StripeSubscriptionID: subscriptionID,
StripeSubscriptionStatus: stripe.SubscriptionStatus(status),
StripeSubscriptionInterval: stripe.PriceRecurringInterval(interval),
StripeSubscriptionPaidUntil: time.Unix(paidUntil, 0),
StripeSubscriptionCancelAt: time.Unix(cancelAt, 0),
}
@ -448,20 +475,16 @@ func (s *Server) updateSubscriptionAndTier(r *http.Request, v *visitor, u *user.
// fetchStripePrices contacts the Stripe API to retrieve all prices. This is used by the server to cache the prices
// in memory, and ultimately for the web app to display the price table.
func (s *Server) fetchStripePrices() (map[string]string, error) {
func (s *Server) fetchStripePrices() (map[string]int64, error) {
log.Debug("Caching prices from Stripe API")
priceMap := make(map[string]string)
priceMap := make(map[string]int64)
prices, err := s.stripe.ListPrices(&stripe.PriceListParams{Active: stripe.Bool(true)})
if err != nil {
log.Warn("Fetching Stripe prices failed: %s", err.Error())
return nil, err
}
for _, p := range prices {
if p.UnitAmount%100 == 0 {
priceMap[p.ID] = fmt.Sprintf("$%d", p.UnitAmount/100)
} else {
priceMap[p.ID] = fmt.Sprintf("$%.2f", float64(p.UnitAmount)/100)
}
priceMap[p.ID] = p.UnitAmount
log.Trace("- Caching price %s = %v", p.ID, priceMap[p.ID])
}
return priceMap, nil