1
0
Fork 0
mirror of https://github.com/binwiederhier/ntfy.git synced 2025-06-01 11:09:30 +02:00

Payments webhook test, delete attachments/messages when reservations are removed,

This commit is contained in:
binwiederhier 2023-01-20 22:47:37 -05:00
parent 45b97c7054
commit 31a3bb7cd6
16 changed files with 571 additions and 157 deletions

View file

@ -27,6 +27,28 @@ var (
errNoBillingSubscription = errors.New("user does not have an active billing subscription")
)
// Payments in ntfy are done via Stripe.
//
// Pretty much all payments related things are in this file. The following processes
// handle payments:
//
// - Checkout:
// Creating a Stripe customer and subscription via the Checkout flow. This flow is only used if the
// ntfy user is not already a Stripe customer. This requires redirecting to the Stripe checkout page.
// It is implemented in handleAccountBillingSubscriptionCreate and the success callback
// handleAccountBillingSubscriptionCreateSuccess.
// - Update subscription:
// Switching between Stripe subscriptions (upgrade/downgrade) is handled via
// handleAccountBillingSubscriptionUpdate. This also handles proration.
// - Cancel subscription (at period end):
// Users can cancel the Stripe subscription via the web app at the end of the billing period. This
// simply updates the subscription and Stripe will cancel it. Users cannot immediately cancel the
// subscription.
// - Webhooks:
// Whenever a subscription changes (updated, deleted), Stripe sends us a request via a webhook.
// This is used to keep the local user database fields up to date. Stripe is the source of truth.
// What Stripe says is mirrored and not questioned.
// handleBillingTiersGet returns all available paid tiers, and the free tier. This is to populate the upgrade dialog
// in the UI. Note that this endpoint does NOT have a user context (no v.user!).
func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
@ -37,7 +59,7 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
freeTier := defaultVisitorLimits(s.config)
response := []*apiAccountBillingTier{
{
// Free tier: no code, name or price
// This is a bit of a hack: This is the "Free" tier. It has no tier code, name or price.
Limits: &apiAccountLimits{
Messages: freeTier.MessagesLimit,
MessagesExpiryDuration: int64(freeTier.MessagesExpiryDuration.Seconds()),
@ -130,6 +152,9 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
return s.writeJSON(w, response)
}
// handleAccountBillingSubscriptionCreateSuccess is called after the Stripe checkout session has succeeded. We use
// the session ID in the URL to retrieve the Stripe subscription and update the local database. This is the first
// and only time we can map the local username with the Stripe customer ID.
func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, _ *visitor) error {
// We don't have a v.user in this endpoint, only a userManager!
matches := apiAccountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path)
@ -139,8 +164,7 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
sessionID := matches[1]
sess, err := s.stripe.GetSession(sessionID) // FIXME How do we rate limit this?
if err != nil {
log.Warn("Stripe: %s", err)
return errHTTPBadRequestBillingRequestInvalid
return err
} else if sess.Customer == nil || sess.Subscription == nil || sess.ClientReferenceID == "" {
return wrapErrHTTP(errHTTPBadRequestBillingRequestInvalid, "customer or subscription not found")
}
@ -158,7 +182,7 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
if err != nil {
return err
}
if err := s.updateSubscriptionAndTier(u, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt, tier.Code); err != nil {
if err := s.updateSubscriptionAndTier(u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil {
return err
}
http.Redirect(w, r, s.config.BaseURL+accountPath, http.StatusSeeOther)
@ -216,6 +240,8 @@ func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r
return s.writeJSON(w, newSuccessResponse())
}
// handleAccountBillingPortalSessionCreate creates a session to the customer billing portal, and returns the
// redirect URL. The billing portal allows customers to change their payment methods, and cancel the subscription.
func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
if v.user.Billing.StripeCustomerID == "" {
return errHTTPBadRequestNotAPaidUser
@ -250,10 +276,11 @@ func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Requ
}
event, err := s.stripe.ConstructWebhookEvent(body.PeekedBytes, stripeSignature, s.config.StripeWebhookKey)
if err != nil {
return errHTTPBadRequestBillingRequestInvalid
return err
} else if event.Data == nil || event.Data.Raw == nil {
return errHTTPBadRequestBillingRequestInvalid
}
log.Info("Stripe: webhook event %s received", event.Type)
switch event.Type {
case "customer.subscription.updated":
@ -282,7 +309,7 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe
if err != nil {
return err
}
if err := s.updateSubscriptionAndTier(u, r.Customer, subscriptionID, r.Status, r.CurrentPeriodEnd, r.CancelAt, tier.Code); err != nil {
if err := s.updateSubscriptionAndTier(u, tier, r.Customer, subscriptionID, r.Status, r.CurrentPeriodEnd, r.CancelAt); err != nil {
return err
}
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
@ -301,29 +328,54 @@ func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMe
if err != nil {
return err
}
if err := s.updateSubscriptionAndTier(u, r.Customer, "", "", 0, 0, ""); err != nil {
if err := s.updateSubscriptionAndTier(u, nil, r.Customer, "", "", 0, 0); err != nil {
return err
}
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
return nil
}
func (s *Server) updateSubscriptionAndTier(u *user.User, customerID, subscriptionID, status string, paidUntil, cancelAt int64, tier string) error {
u.Billing.StripeCustomerID = customerID
u.Billing.StripeSubscriptionID = subscriptionID
u.Billing.StripeSubscriptionStatus = stripe.SubscriptionStatus(status)
u.Billing.StripeSubscriptionPaidUntil = time.Unix(paidUntil, 0)
u.Billing.StripeSubscriptionCancelAt = time.Unix(cancelAt, 0)
if tier == "" {
func (s *Server) updateSubscriptionAndTier(u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error {
// Remove excess reservations (if too many for tier), and mark associated messages deleted
reservations, err := s.userManager.Reservations(u.Name)
if err != nil {
return err
}
reservationsLimit := visitorDefaultReservationsLimit
if tier != nil {
reservationsLimit = tier.ReservationsLimit
}
if int64(len(reservations)) > reservationsLimit {
topics := make([]string, 0)
for i := int64(len(reservations)) - 1; i >= reservationsLimit; i-- {
topics = append(topics, reservations[i].Topic)
}
if err := s.userManager.RemoveReservations(u.Name, topics...); err != nil {
return err
}
if err := s.messageCache.ExpireMessages(topics...); err != nil {
return err
}
}
// Change or remove tier
if tier == nil {
if err := s.userManager.ResetTier(u.Name); err != nil {
return err
}
} else {
if err := s.userManager.ChangeTier(u.Name, tier); err != nil {
if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil {
return err
}
}
if err := s.userManager.ChangeBilling(u); err != nil {
// Update billing fields
billing := &user.Billing{
StripeCustomerID: customerID,
StripeSubscriptionID: subscriptionID,
StripeSubscriptionStatus: stripe.SubscriptionStatus(status),
StripeSubscriptionPaidUntil: time.Unix(paidUntil, 0),
StripeSubscriptionCancelAt: time.Unix(cancelAt, 0),
}
if err := s.userManager.ChangeBilling(u.Name, billing); err != nil {
return err
}
return nil