1
0
Fork 0
mirror of https://github.com/binwiederhier/ntfy.git synced 2025-06-02 03:20:34 +02:00

No more v.user races

This commit is contained in:
binwiederhier 2023-01-28 20:43:06 -05:00
parent e596834096
commit 92d563371c
5 changed files with 87 additions and 77 deletions

View file

@ -54,7 +54,7 @@ var (
)
// handleBillingTiersGet returns all available paid tiers, and the free tier. This is to populate the upgrade dialog
// in the UI. Note that this endpoint does NOT have a user context (no v.user!).
// in the UI. Note that this endpoint does NOT have a user context (no u!).
func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
tiers, err := s.userManager.Tiers()
if err != nil {
@ -107,7 +107,8 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
// handleAccountBillingSubscriptionCreate creates a Stripe checkout flow to create a user subscription. The tier
// will be updated by a subsequent webhook from Stripe, once the subscription becomes active.
func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
if v.user.Billing.StripeSubscriptionID != "" {
u := v.User()
if u.Billing.StripeSubscriptionID != "" {
return errHTTPBadRequestBillingSubscriptionExists
}
req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit, false)
@ -122,9 +123,9 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
}
log.Info("%s Creating Stripe checkout flow", logHTTPPrefix(v, r))
var stripeCustomerID *string
if v.user.Billing.StripeCustomerID != "" {
stripeCustomerID = &v.user.Billing.StripeCustomerID
stripeCustomer, err := s.stripe.GetCustomer(v.user.Billing.StripeCustomerID)
if u.Billing.StripeCustomerID != "" {
stripeCustomerID = &u.Billing.StripeCustomerID
stripeCustomer, err := s.stripe.GetCustomer(u.Billing.StripeCustomerID)
if err != nil {
return err
} else if stripeCustomer.Subscriptions != nil && len(stripeCustomer.Subscriptions.Data) > 0 {
@ -134,7 +135,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
successURL := s.config.BaseURL + apiAccountBillingSubscriptionCheckoutSuccessTemplate
params := &stripe.CheckoutSessionParams{
Customer: stripeCustomerID, // A user may have previously deleted their subscription
ClientReferenceID: &v.user.ID,
ClientReferenceID: &u.ID,
SuccessURL: &successURL,
Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
AllowPromotionCodes: stripe.Bool(true),
@ -146,7 +147,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
},
Params: stripe.Params{
Metadata: map[string]string{
"user_id": v.user.ID,
"user_id": u.ID,
},
},
}
@ -164,7 +165,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
// the session ID in the URL to retrieve the Stripe subscription and update the local database. This is the first
// and only time we can map the local username with the Stripe customer ID.
func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, v *visitor) error {
// We don't have a v.user in this endpoint, only a userManager!
// We don't have v.User() in this endpoint, only a userManager!
matches := apiAccountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path)
if len(matches) != 2 {
return errHTTPInternalErrorInvalidPath
@ -212,7 +213,8 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
// handleAccountBillingSubscriptionUpdate updates an existing Stripe subscription to a new price, and updates
// a user's tier accordingly. This endpoint only works if there is an existing subscription.
func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r *http.Request, v *visitor) error {
if v.user.Billing.StripeSubscriptionID == "" {
u := v.User()
if u.Billing.StripeSubscriptionID == "" {
return errNoBillingSubscription
}
req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit, false)
@ -223,8 +225,8 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
if err != nil {
return err
}
log.Info("%s Changing billing tier to %s (price %s) for subscription %s", logHTTPPrefix(v, r), tier.Code, tier.StripePriceID, v.user.Billing.StripeSubscriptionID)
sub, err := s.stripe.GetSubscription(v.user.Billing.StripeSubscriptionID)
log.Info("%s Changing billing tier to %s (price %s) for subscription %s", logHTTPPrefix(v, r), tier.Code, tier.StripePriceID, u.Billing.StripeSubscriptionID)
sub, err := s.stripe.GetSubscription(u.Billing.StripeSubscriptionID)
if err != nil {
return err
}
@ -248,12 +250,13 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
// handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user,
// and cancelling the Stripe subscription entirely
func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
log.Info("%s Deleting billing subscription %s", logHTTPPrefix(v, r), v.user.Billing.StripeSubscriptionID)
if v.user.Billing.StripeSubscriptionID != "" {
u := v.User()
log.Info("%s Deleting billing subscription %s", logHTTPPrefix(v, r), u.Billing.StripeSubscriptionID)
if u.Billing.StripeSubscriptionID != "" {
params := &stripe.SubscriptionParams{
CancelAtPeriodEnd: stripe.Bool(true),
}
_, err := s.stripe.UpdateSubscription(v.user.Billing.StripeSubscriptionID, params)
_, err := s.stripe.UpdateSubscription(u.Billing.StripeSubscriptionID, params)
if err != nil {
return err
}
@ -264,12 +267,13 @@ func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r
// handleAccountBillingPortalSessionCreate creates a session to the customer billing portal, and returns the
// redirect URL. The billing portal allows customers to change their payment methods, and cancel the subscription.
func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
if v.user.Billing.StripeCustomerID == "" {
u := v.User()
if u.Billing.StripeCustomerID == "" {
return errHTTPBadRequestNotAPaidUser
}
log.Info("%s Creating billing portal session", logHTTPPrefix(v, r))
params := &stripe.BillingPortalSessionParams{
Customer: stripe.String(v.user.Billing.StripeCustomerID),
Customer: stripe.String(u.Billing.StripeCustomerID),
ReturnURL: stripe.String(s.config.BaseURL),
}
ps, err := s.stripe.NewPortalSession(params)
@ -284,8 +288,8 @@ func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter,
// handleAccountBillingWebhook handles incoming Stripe webhooks. It mainly keeps the local user database in sync
// with the Stripe view of the world. This endpoint is authorized via the Stripe webhook secret. Note that the
// visitor (v) in this endpoint is the Stripe API, so we don't have v.user available.
func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Request, _ *visitor) error {
// visitor (v) in this endpoint is the Stripe API, so we don't have u available.
func (s *Server) handleAccountBillingWebhook(_ http.ResponseWriter, r *http.Request, _ *visitor) error {
stripeSignature := r.Header.Get("Stripe-Signature")
if stripeSignature == "" {
return errHTTPBadRequestBillingRequestInvalid