mirror of
				https://github.com/binwiederhier/ntfy.git
				synced 2025-10-31 13:02:24 +01:00 
			
		
		
		
	No more v.user races
This commit is contained in:
		
							parent
							
								
									e596834096
								
							
						
					
					
						commit
						92d563371c
					
				
					 5 changed files with 87 additions and 77 deletions
				
			
		|  | @ -39,7 +39,6 @@ import ( | ||||||
| - HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...) | - HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...) | ||||||
| - HIGH Stripe payment methods | - HIGH Stripe payment methods | ||||||
| - MEDIUM: Test new token endpoints & never-expiring token | - MEDIUM: Test new token endpoints & never-expiring token | ||||||
| - MEDIUM: Races with v.user (see publishSyncEventAsync test) |  | ||||||
| - MEDIUM: Test that anonymous user and user without tier are the same visitor | - MEDIUM: Test that anonymous user and user without tier are the same visitor | ||||||
| - MEDIUM: Make sure account endpoints make sense for admins | - MEDIUM: Make sure account endpoints make sense for admins | ||||||
| - MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben) | - MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben) | ||||||
|  |  | ||||||
|  | @ -19,11 +19,12 @@ const ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *visitor) error { | func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *visitor) error { | ||||||
| 	admin := v.user != nil && v.user.Role == user.RoleAdmin | 	u := v.User() | ||||||
|  | 	admin := u != nil && u.Role == user.RoleAdmin | ||||||
| 	if !admin { | 	if !admin { | ||||||
| 		if !s.config.EnableSignup { | 		if !s.config.EnableSignup { | ||||||
| 			return errHTTPBadRequestSignupNotEnabled | 			return errHTTPBadRequestSignupNotEnabled | ||||||
| 		} else if v.user != nil { | 		} else if u != nil { | ||||||
| 			return errHTTPUnauthorized // Cannot create account from user context | 			return errHTTPUnauthorized // Cannot create account from user context | ||||||
| 		} | 		} | ||||||
| 		if !v.AccountCreationAllowed() { | 		if !v.AccountCreationAllowed() { | ||||||
|  | @ -150,20 +151,21 @@ func (s *Server) handleAccountDelete(w http.ResponseWriter, r *http.Request, v * | ||||||
| 	} else if req.Password == "" { | 	} else if req.Password == "" { | ||||||
| 		return errHTTPBadRequest | 		return errHTTPBadRequest | ||||||
| 	} | 	} | ||||||
| 	if _, err := s.userManager.Authenticate(v.user.Name, req.Password); err != nil { | 	u := v.User() | ||||||
|  | 	if _, err := s.userManager.Authenticate(u.Name, req.Password); err != nil { | ||||||
| 		return errHTTPBadRequestIncorrectPasswordConfirmation | 		return errHTTPBadRequestIncorrectPasswordConfirmation | ||||||
| 	} | 	} | ||||||
| 	if v.user.Billing.StripeSubscriptionID != "" { | 	if u.Billing.StripeSubscriptionID != "" { | ||||||
| 		log.Info("%s Canceling billing subscription %s", logHTTPPrefix(v, r), v.user.Billing.StripeSubscriptionID) | 		log.Info("%s Canceling billing subscription %s", logHTTPPrefix(v, r), u.Billing.StripeSubscriptionID) | ||||||
| 		if _, err := s.stripe.CancelSubscription(v.user.Billing.StripeSubscriptionID); err != nil { | 		if _, err := s.stripe.CancelSubscription(u.Billing.StripeSubscriptionID); err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	if err := s.maybeRemoveMessagesAndExcessReservations(logHTTPPrefix(v, r), v.user, 0); err != nil { | 	if err := s.maybeRemoveMessagesAndExcessReservations(logHTTPPrefix(v, r), u, 0); err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	log.Info("%s Marking user %s as deleted", logHTTPPrefix(v, r), v.user.Name) | 	log.Info("%s Marking user %s as deleted", logHTTPPrefix(v, r), u.Name) | ||||||
| 	if err := s.userManager.MarkUserRemoved(v.user); err != nil { | 	if err := s.userManager.MarkUserRemoved(u); err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	return s.writeJSON(w, newSuccessResponse()) | 	return s.writeJSON(w, newSuccessResponse()) | ||||||
|  | @ -176,10 +178,11 @@ func (s *Server) handleAccountPasswordChange(w http.ResponseWriter, r *http.Requ | ||||||
| 	} else if req.Password == "" || req.NewPassword == "" { | 	} else if req.Password == "" || req.NewPassword == "" { | ||||||
| 		return errHTTPBadRequest | 		return errHTTPBadRequest | ||||||
| 	} | 	} | ||||||
| 	if _, err := s.userManager.Authenticate(v.user.Name, req.Password); err != nil { | 	u := v.User() | ||||||
|  | 	if _, err := s.userManager.Authenticate(u.Name, req.Password); err != nil { | ||||||
| 		return errHTTPBadRequestIncorrectPasswordConfirmation | 		return errHTTPBadRequestIncorrectPasswordConfirmation | ||||||
| 	} | 	} | ||||||
| 	if err := s.userManager.ChangePassword(v.user.Name, req.NewPassword); err != nil { | 	if err := s.userManager.ChangePassword(u.Name, req.NewPassword); err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	return s.writeJSON(w, newSuccessResponse()) | 	return s.writeJSON(w, newSuccessResponse()) | ||||||
|  | @ -267,10 +270,11 @@ func (s *Server) handleAccountSettingsChange(w http.ResponseWriter, r *http.Requ | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	if v.user.Prefs == nil { | 	u := v.User() | ||||||
| 		v.user.Prefs = &user.Prefs{} | 	if u.Prefs == nil { | ||||||
|  | 		u.Prefs = &user.Prefs{} | ||||||
| 	} | 	} | ||||||
| 	prefs := v.user.Prefs | 	prefs := u.Prefs | ||||||
| 	if newPrefs.Language != nil { | 	if newPrefs.Language != nil { | ||||||
| 		prefs.Language = newPrefs.Language | 		prefs.Language = newPrefs.Language | ||||||
| 	} | 	} | ||||||
|  | @ -288,7 +292,7 @@ func (s *Server) handleAccountSettingsChange(w http.ResponseWriter, r *http.Requ | ||||||
| 			prefs.Notification.MinPriority = newPrefs.Notification.MinPriority | 			prefs.Notification.MinPriority = newPrefs.Notification.MinPriority | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	if err := s.userManager.ChangeSettings(v.user); err != nil { | 	if err := s.userManager.ChangeSettings(u); err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	return s.writeJSON(w, newSuccessResponse()) | 	return s.writeJSON(w, newSuccessResponse()) | ||||||
|  | @ -299,11 +303,12 @@ func (s *Server) handleAccountSubscriptionAdd(w http.ResponseWriter, r *http.Req | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	if v.user.Prefs == nil { | 	u := v.User() | ||||||
| 		v.user.Prefs = &user.Prefs{} | 	if u.Prefs == nil { | ||||||
|  | 		u.Prefs = &user.Prefs{} | ||||||
| 	} | 	} | ||||||
| 	newSubscription.ID = "" // Client cannot set ID | 	newSubscription.ID = "" // Client cannot set ID | ||||||
| 	for _, subscription := range v.user.Prefs.Subscriptions { | 	for _, subscription := range u.Prefs.Subscriptions { | ||||||
| 		if newSubscription.BaseURL == subscription.BaseURL && newSubscription.Topic == subscription.Topic { | 		if newSubscription.BaseURL == subscription.BaseURL && newSubscription.Topic == subscription.Topic { | ||||||
| 			newSubscription = subscription | 			newSubscription = subscription | ||||||
| 			break | 			break | ||||||
|  | @ -311,8 +316,8 @@ func (s *Server) handleAccountSubscriptionAdd(w http.ResponseWriter, r *http.Req | ||||||
| 	} | 	} | ||||||
| 	if newSubscription.ID == "" { | 	if newSubscription.ID == "" { | ||||||
| 		newSubscription.ID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength) | 		newSubscription.ID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength) | ||||||
| 		v.user.Prefs.Subscriptions = append(v.user.Prefs.Subscriptions, newSubscription) | 		u.Prefs.Subscriptions = append(u.Prefs.Subscriptions, newSubscription) | ||||||
| 		if err := s.userManager.ChangeSettings(v.user); err != nil { | 		if err := s.userManager.ChangeSettings(u); err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | @ -329,11 +334,12 @@ func (s *Server) handleAccountSubscriptionChange(w http.ResponseWriter, r *http. | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	if v.user.Prefs == nil || v.user.Prefs.Subscriptions == nil { | 	u := v.User() | ||||||
|  | 	if u.Prefs == nil || u.Prefs.Subscriptions == nil { | ||||||
| 		return errHTTPNotFound | 		return errHTTPNotFound | ||||||
| 	} | 	} | ||||||
| 	var subscription *user.Subscription | 	var subscription *user.Subscription | ||||||
| 	for _, sub := range v.user.Prefs.Subscriptions { | 	for _, sub := range u.Prefs.Subscriptions { | ||||||
| 		if sub.ID == subscriptionID { | 		if sub.ID == subscriptionID { | ||||||
| 			sub.DisplayName = updatedSubscription.DisplayName | 			sub.DisplayName = updatedSubscription.DisplayName | ||||||
| 			subscription = sub | 			subscription = sub | ||||||
|  | @ -343,7 +349,7 @@ func (s *Server) handleAccountSubscriptionChange(w http.ResponseWriter, r *http. | ||||||
| 	if subscription == nil { | 	if subscription == nil { | ||||||
| 		return errHTTPNotFound | 		return errHTTPNotFound | ||||||
| 	} | 	} | ||||||
| 	if err := s.userManager.ChangeSettings(v.user); err != nil { | 	if err := s.userManager.ChangeSettings(u); err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	return s.writeJSON(w, subscription) | 	return s.writeJSON(w, subscription) | ||||||
|  | @ -355,18 +361,19 @@ func (s *Server) handleAccountSubscriptionDelete(w http.ResponseWriter, r *http. | ||||||
| 		return errHTTPInternalErrorInvalidPath | 		return errHTTPInternalErrorInvalidPath | ||||||
| 	} | 	} | ||||||
| 	subscriptionID := matches[1] | 	subscriptionID := matches[1] | ||||||
| 	if v.user.Prefs == nil || v.user.Prefs.Subscriptions == nil { | 	u := v.User() | ||||||
|  | 	if u.Prefs == nil || u.Prefs.Subscriptions == nil { | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 	newSubscriptions := make([]*user.Subscription, 0) | 	newSubscriptions := make([]*user.Subscription, 0) | ||||||
| 	for _, subscription := range v.user.Prefs.Subscriptions { | 	for _, subscription := range u.Prefs.Subscriptions { | ||||||
| 		if subscription.ID != subscriptionID { | 		if subscription.ID != subscriptionID { | ||||||
| 			newSubscriptions = append(newSubscriptions, subscription) | 			newSubscriptions = append(newSubscriptions, subscription) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	if len(newSubscriptions) < len(v.user.Prefs.Subscriptions) { | 	if len(newSubscriptions) < len(u.Prefs.Subscriptions) { | ||||||
| 		v.user.Prefs.Subscriptions = newSubscriptions | 		u.Prefs.Subscriptions = newSubscriptions | ||||||
| 		if err := s.userManager.ChangeSettings(v.user); err != nil { | 		if err := s.userManager.ChangeSettings(u); err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | @ -374,7 +381,8 @@ func (s *Server) handleAccountSubscriptionDelete(w http.ResponseWriter, r *http. | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Request, v *visitor) error { | func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Request, v *visitor) error { | ||||||
| 	if v.user != nil && v.user.Role == user.RoleAdmin { | 	u := v.User() | ||||||
|  | 	if u != nil && u.Role == user.RoleAdmin { | ||||||
| 		return errHTTPBadRequestMakesNoSenseForAdmin | 		return errHTTPBadRequestMakesNoSenseForAdmin | ||||||
| 	} | 	} | ||||||
| 	req, err := readJSONWithLimit[apiAccountReservationRequest](r.Body, jsonBodyBytesLimit, false) | 	req, err := readJSONWithLimit[apiAccountReservationRequest](r.Body, jsonBodyBytesLimit, false) | ||||||
|  | @ -388,27 +396,27 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errHTTPBadRequestPermissionInvalid | 		return errHTTPBadRequestPermissionInvalid | ||||||
| 	} | 	} | ||||||
| 	if v.user.Tier == nil { | 	if u.Tier == nil { | ||||||
| 		return errHTTPUnauthorized | 		return errHTTPUnauthorized | ||||||
| 	} | 	} | ||||||
| 	// CHeck if we are allowed to reserve this topic | 	// CHeck if we are allowed to reserve this topic | ||||||
| 	if err := s.userManager.CheckAllowAccess(v.user.Name, req.Topic); err != nil { | 	if err := s.userManager.CheckAllowAccess(u.Name, req.Topic); err != nil { | ||||||
| 		return errHTTPConflictTopicReserved | 		return errHTTPConflictTopicReserved | ||||||
| 	} | 	} | ||||||
| 	hasReservation, err := s.userManager.HasReservation(v.user.Name, req.Topic) | 	hasReservation, err := s.userManager.HasReservation(u.Name, req.Topic) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	if !hasReservation { | 	if !hasReservation { | ||||||
| 		reservations, err := s.userManager.ReservationsCount(v.user.Name) | 		reservations, err := s.userManager.ReservationsCount(u.Name) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} else if reservations >= v.user.Tier.ReservationLimit { | 		} else if reservations >= u.Tier.ReservationLimit { | ||||||
| 			return errHTTPTooManyRequestsLimitReservations | 			return errHTTPTooManyRequestsLimitReservations | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	// Actually add the reservation | 	// Actually add the reservation | ||||||
| 	if err := s.userManager.AddReservation(v.user.Name, req.Topic, everyone); err != nil { | 	if err := s.userManager.AddReservation(u.Name, req.Topic, everyone); err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	// Kill existing subscribers | 	// Kill existing subscribers | ||||||
|  | @ -416,7 +424,7 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	t.CancelSubscribers(v.user.ID) | 	t.CancelSubscribers(u.ID) | ||||||
| 	return s.writeJSON(w, newSuccessResponse()) | 	return s.writeJSON(w, newSuccessResponse()) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -429,13 +437,14 @@ func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.R | ||||||
| 	if !topicRegex.MatchString(topic) { | 	if !topicRegex.MatchString(topic) { | ||||||
| 		return errHTTPBadRequestTopicInvalid | 		return errHTTPBadRequestTopicInvalid | ||||||
| 	} | 	} | ||||||
| 	authorized, err := s.userManager.HasReservation(v.user.Name, topic) | 	u := v.User() | ||||||
|  | 	authorized, err := s.userManager.HasReservation(u.Name, topic) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} else if !authorized { | 	} else if !authorized { | ||||||
| 		return errHTTPUnauthorized | 		return errHTTPUnauthorized | ||||||
| 	} | 	} | ||||||
| 	if err := s.userManager.RemoveReservations(v.user.Name, topic); err != nil { | 	if err := s.userManager.RemoveReservations(u.Name, topic); err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	return s.writeJSON(w, newSuccessResponse()) | 	return s.writeJSON(w, newSuccessResponse()) | ||||||
|  | @ -465,12 +474,23 @@ func (s *Server) maybeRemoveMessagesAndExcessReservations(logPrefix string, u *u | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // publishSyncEventAsync kicks of a Go routine to publish a sync message to the user's sync topic | ||||||
|  | func (s *Server) publishSyncEventAsync(v *visitor) { | ||||||
|  | 	go func() { | ||||||
|  | 		if err := s.publishSyncEvent(v); err != nil { | ||||||
|  | 			log.Trace("%s Error publishing to user's sync topic: %s", v.String(), err.Error()) | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // publishSyncEvent publishes a sync message to the user's sync topic | ||||||
| func (s *Server) publishSyncEvent(v *visitor) error { | func (s *Server) publishSyncEvent(v *visitor) error { | ||||||
| 	if v.user == nil || v.user.SyncTopic == "" { | 	u := v.User() | ||||||
|  | 	if u == nil || u.SyncTopic == "" { | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 	log.Trace("Publishing sync event to user %s's sync topic %s", v.user.Name, v.user.SyncTopic) | 	log.Trace("Publishing sync event to user %s's sync topic %s", u.Name, u.SyncTopic) | ||||||
| 	syncTopic, err := s.topicFromID(v.user.SyncTopic) | 	syncTopic, err := s.topicFromID(u.SyncTopic) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  | @ -484,15 +504,3 @@ func (s *Server) publishSyncEvent(v *visitor) error { | ||||||
| 	} | 	} | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 |  | ||||||
| func (s *Server) publishSyncEventAsync(v *visitor) { |  | ||||||
| 	go func() { |  | ||||||
| 		u := v.User() |  | ||||||
| 		if u == nil || u.SyncTopic == "" { |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 		if err := s.publishSyncEvent(v); err != nil { |  | ||||||
| 			log.Trace("Error publishing to user %s's sync topic %s: %s", u.Name, u.SyncTopic, err.Error()) |  | ||||||
| 		} |  | ||||||
| 	}() |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  | @ -24,7 +24,7 @@ func (s *Server) ensureUserManager(next handleFunc) handleFunc { | ||||||
| 
 | 
 | ||||||
| func (s *Server) ensureUser(next handleFunc) handleFunc { | func (s *Server) ensureUser(next handleFunc) handleFunc { | ||||||
| 	return s.ensureUserManager(func(w http.ResponseWriter, r *http.Request, v *visitor) error { | 	return s.ensureUserManager(func(w http.ResponseWriter, r *http.Request, v *visitor) error { | ||||||
| 		if v.user == nil { | 		if v.User() == nil { | ||||||
| 			return errHTTPUnauthorized | 			return errHTTPUnauthorized | ||||||
| 		} | 		} | ||||||
| 		return next(w, r, v) | 		return next(w, r, v) | ||||||
|  | @ -42,7 +42,7 @@ func (s *Server) ensurePaymentsEnabled(next handleFunc) handleFunc { | ||||||
| 
 | 
 | ||||||
| func (s *Server) ensureStripeCustomer(next handleFunc) handleFunc { | func (s *Server) ensureStripeCustomer(next handleFunc) handleFunc { | ||||||
| 	return s.ensureUser(func(w http.ResponseWriter, r *http.Request, v *visitor) error { | 	return s.ensureUser(func(w http.ResponseWriter, r *http.Request, v *visitor) error { | ||||||
| 		if v.user.Billing.StripeCustomerID == "" { | 		if v.User().Billing.StripeCustomerID == "" { | ||||||
| 			return errHTTPBadRequestNotAPaidUser | 			return errHTTPBadRequestNotAPaidUser | ||||||
| 		} | 		} | ||||||
| 		return next(w, r, v) | 		return next(w, r, v) | ||||||
|  | @ -51,9 +51,6 @@ func (s *Server) ensureStripeCustomer(next handleFunc) handleFunc { | ||||||
| 
 | 
 | ||||||
| func (s *Server) withAccountSync(next handleFunc) handleFunc { | func (s *Server) withAccountSync(next handleFunc) handleFunc { | ||||||
| 	return func(w http.ResponseWriter, r *http.Request, v *visitor) error { | 	return func(w http.ResponseWriter, r *http.Request, v *visitor) error { | ||||||
| 		if v.user == nil { |  | ||||||
| 			return next(w, r, v) |  | ||||||
| 		} |  | ||||||
| 		err := next(w, r, v) | 		err := next(w, r, v) | ||||||
| 		if err == nil { | 		if err == nil { | ||||||
| 			s.publishSyncEventAsync(v) | 			s.publishSyncEventAsync(v) | ||||||
|  |  | ||||||
|  | @ -54,7 +54,7 @@ var ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // handleBillingTiersGet returns all available paid tiers, and the free tier. This is to populate the upgrade dialog | // handleBillingTiersGet returns all available paid tiers, and the free tier. This is to populate the upgrade dialog | ||||||
| // in the UI. Note that this endpoint does NOT have a user context (no v.user!). | // in the UI. Note that this endpoint does NOT have a user context (no u!). | ||||||
| func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _ *visitor) error { | func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _ *visitor) error { | ||||||
| 	tiers, err := s.userManager.Tiers() | 	tiers, err := s.userManager.Tiers() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | @ -107,7 +107,8 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _ | ||||||
| // handleAccountBillingSubscriptionCreate creates a Stripe checkout flow to create a user subscription. The tier | // handleAccountBillingSubscriptionCreate creates a Stripe checkout flow to create a user subscription. The tier | ||||||
| // will be updated by a subsequent webhook from Stripe, once the subscription becomes active. | // will be updated by a subsequent webhook from Stripe, once the subscription becomes active. | ||||||
| func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error { | func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error { | ||||||
| 	if v.user.Billing.StripeSubscriptionID != "" { | 	u := v.User() | ||||||
|  | 	if u.Billing.StripeSubscriptionID != "" { | ||||||
| 		return errHTTPBadRequestBillingSubscriptionExists | 		return errHTTPBadRequestBillingSubscriptionExists | ||||||
| 	} | 	} | ||||||
| 	req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit, false) | 	req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit, false) | ||||||
|  | @ -122,9 +123,9 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r | ||||||
| 	} | 	} | ||||||
| 	log.Info("%s Creating Stripe checkout flow", logHTTPPrefix(v, r)) | 	log.Info("%s Creating Stripe checkout flow", logHTTPPrefix(v, r)) | ||||||
| 	var stripeCustomerID *string | 	var stripeCustomerID *string | ||||||
| 	if v.user.Billing.StripeCustomerID != "" { | 	if u.Billing.StripeCustomerID != "" { | ||||||
| 		stripeCustomerID = &v.user.Billing.StripeCustomerID | 		stripeCustomerID = &u.Billing.StripeCustomerID | ||||||
| 		stripeCustomer, err := s.stripe.GetCustomer(v.user.Billing.StripeCustomerID) | 		stripeCustomer, err := s.stripe.GetCustomer(u.Billing.StripeCustomerID) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} else if stripeCustomer.Subscriptions != nil && len(stripeCustomer.Subscriptions.Data) > 0 { | 		} else if stripeCustomer.Subscriptions != nil && len(stripeCustomer.Subscriptions.Data) > 0 { | ||||||
|  | @ -134,7 +135,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r | ||||||
| 	successURL := s.config.BaseURL + apiAccountBillingSubscriptionCheckoutSuccessTemplate | 	successURL := s.config.BaseURL + apiAccountBillingSubscriptionCheckoutSuccessTemplate | ||||||
| 	params := &stripe.CheckoutSessionParams{ | 	params := &stripe.CheckoutSessionParams{ | ||||||
| 		Customer:            stripeCustomerID, // A user may have previously deleted their subscription | 		Customer:            stripeCustomerID, // A user may have previously deleted their subscription | ||||||
| 		ClientReferenceID:   &v.user.ID, | 		ClientReferenceID:   &u.ID, | ||||||
| 		SuccessURL:          &successURL, | 		SuccessURL:          &successURL, | ||||||
| 		Mode:                stripe.String(string(stripe.CheckoutSessionModeSubscription)), | 		Mode:                stripe.String(string(stripe.CheckoutSessionModeSubscription)), | ||||||
| 		AllowPromotionCodes: stripe.Bool(true), | 		AllowPromotionCodes: stripe.Bool(true), | ||||||
|  | @ -146,7 +147,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r | ||||||
| 		}, | 		}, | ||||||
| 		Params: stripe.Params{ | 		Params: stripe.Params{ | ||||||
| 			Metadata: map[string]string{ | 			Metadata: map[string]string{ | ||||||
| 				"user_id": v.user.ID, | 				"user_id": u.ID, | ||||||
| 			}, | 			}, | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
|  | @ -164,7 +165,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r | ||||||
| // the session ID in the URL to retrieve the Stripe subscription and update the local database. This is the first | // the session ID in the URL to retrieve the Stripe subscription and update the local database. This is the first | ||||||
| // and only time we can map the local username with the Stripe customer ID. | // and only time we can map the local username with the Stripe customer ID. | ||||||
| func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, v *visitor) error { | func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, v *visitor) error { | ||||||
| 	// We don't have a v.user in this endpoint, only a userManager! | 	// We don't have v.User() in this endpoint, only a userManager! | ||||||
| 	matches := apiAccountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path) | 	matches := apiAccountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path) | ||||||
| 	if len(matches) != 2 { | 	if len(matches) != 2 { | ||||||
| 		return errHTTPInternalErrorInvalidPath | 		return errHTTPInternalErrorInvalidPath | ||||||
|  | @ -212,7 +213,8 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr | ||||||
| // handleAccountBillingSubscriptionUpdate updates an existing Stripe subscription to a new price, and updates | // handleAccountBillingSubscriptionUpdate updates an existing Stripe subscription to a new price, and updates | ||||||
| // a user's tier accordingly. This endpoint only works if there is an existing subscription. | // a user's tier accordingly. This endpoint only works if there is an existing subscription. | ||||||
| func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r *http.Request, v *visitor) error { | func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r *http.Request, v *visitor) error { | ||||||
| 	if v.user.Billing.StripeSubscriptionID == "" { | 	u := v.User() | ||||||
|  | 	if u.Billing.StripeSubscriptionID == "" { | ||||||
| 		return errNoBillingSubscription | 		return errNoBillingSubscription | ||||||
| 	} | 	} | ||||||
| 	req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit, false) | 	req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit, false) | ||||||
|  | @ -223,8 +225,8 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	log.Info("%s Changing billing tier to %s (price %s) for subscription %s", logHTTPPrefix(v, r), tier.Code, tier.StripePriceID, v.user.Billing.StripeSubscriptionID) | 	log.Info("%s Changing billing tier to %s (price %s) for subscription %s", logHTTPPrefix(v, r), tier.Code, tier.StripePriceID, u.Billing.StripeSubscriptionID) | ||||||
| 	sub, err := s.stripe.GetSubscription(v.user.Billing.StripeSubscriptionID) | 	sub, err := s.stripe.GetSubscription(u.Billing.StripeSubscriptionID) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  | @ -248,12 +250,13 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r | ||||||
| // handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user, | // handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user, | ||||||
| // and cancelling the Stripe subscription entirely | // and cancelling the Stripe subscription entirely | ||||||
| func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error { | func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error { | ||||||
| 	log.Info("%s Deleting billing subscription %s", logHTTPPrefix(v, r), v.user.Billing.StripeSubscriptionID) | 	u := v.User() | ||||||
| 	if v.user.Billing.StripeSubscriptionID != "" { | 	log.Info("%s Deleting billing subscription %s", logHTTPPrefix(v, r), u.Billing.StripeSubscriptionID) | ||||||
|  | 	if u.Billing.StripeSubscriptionID != "" { | ||||||
| 		params := &stripe.SubscriptionParams{ | 		params := &stripe.SubscriptionParams{ | ||||||
| 			CancelAtPeriodEnd: stripe.Bool(true), | 			CancelAtPeriodEnd: stripe.Bool(true), | ||||||
| 		} | 		} | ||||||
| 		_, err := s.stripe.UpdateSubscription(v.user.Billing.StripeSubscriptionID, params) | 		_, err := s.stripe.UpdateSubscription(u.Billing.StripeSubscriptionID, params) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
|  | @ -264,12 +267,13 @@ func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r | ||||||
| // handleAccountBillingPortalSessionCreate creates a session to the customer billing portal, and returns the | // handleAccountBillingPortalSessionCreate creates a session to the customer billing portal, and returns the | ||||||
| // redirect URL. The billing portal allows customers to change their payment methods, and cancel the subscription. | // redirect URL. The billing portal allows customers to change their payment methods, and cancel the subscription. | ||||||
| func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error { | func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error { | ||||||
| 	if v.user.Billing.StripeCustomerID == "" { | 	u := v.User() | ||||||
|  | 	if u.Billing.StripeCustomerID == "" { | ||||||
| 		return errHTTPBadRequestNotAPaidUser | 		return errHTTPBadRequestNotAPaidUser | ||||||
| 	} | 	} | ||||||
| 	log.Info("%s Creating billing portal session", logHTTPPrefix(v, r)) | 	log.Info("%s Creating billing portal session", logHTTPPrefix(v, r)) | ||||||
| 	params := &stripe.BillingPortalSessionParams{ | 	params := &stripe.BillingPortalSessionParams{ | ||||||
| 		Customer:  stripe.String(v.user.Billing.StripeCustomerID), | 		Customer:  stripe.String(u.Billing.StripeCustomerID), | ||||||
| 		ReturnURL: stripe.String(s.config.BaseURL), | 		ReturnURL: stripe.String(s.config.BaseURL), | ||||||
| 	} | 	} | ||||||
| 	ps, err := s.stripe.NewPortalSession(params) | 	ps, err := s.stripe.NewPortalSession(params) | ||||||
|  | @ -284,8 +288,8 @@ func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, | ||||||
| 
 | 
 | ||||||
| // handleAccountBillingWebhook handles incoming Stripe webhooks. It mainly keeps the local user database in sync | // handleAccountBillingWebhook handles incoming Stripe webhooks. It mainly keeps the local user database in sync | ||||||
| // with the Stripe view of the world. This endpoint is authorized via the Stripe webhook secret. Note that the | // with the Stripe view of the world. This endpoint is authorized via the Stripe webhook secret. Note that the | ||||||
| // visitor (v) in this endpoint is the Stripe API, so we don't have v.user available. | // visitor (v) in this endpoint is the Stripe API, so we don't have u available. | ||||||
| func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Request, _ *visitor) error { | func (s *Server) handleAccountBillingWebhook(_ http.ResponseWriter, r *http.Request, _ *visitor) error { | ||||||
| 	stripeSignature := r.Header.Get("Stripe-Signature") | 	stripeSignature := r.Header.Get("Stripe-Signature") | ||||||
| 	if stripeSignature == "" { | 	if stripeSignature == "" { | ||||||
| 		return errHTTPBadRequestBillingRequestInvalid | 		return errHTTPBadRequestBillingRequestInvalid | ||||||
|  |  | ||||||
|  | @ -30,6 +30,7 @@ const ( | ||||||
| 	tokenMaxCount                   = 10 // Only keep this many tokens in the table per user | 	tokenMaxCount                   = 10 // Only keep this many tokens in the table per user | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // Default constants that may be overridden by configs | ||||||
| const ( | const ( | ||||||
| 	DefaultUserStatsQueueWriterInterval = 33 * time.Second | 	DefaultUserStatsQueueWriterInterval = 33 * time.Second | ||||||
| 	DefaultUserPasswordBcryptCost       = 10 | 	DefaultUserPasswordBcryptCost       = 10 | ||||||
|  | @ -1195,6 +1196,7 @@ func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) { | ||||||
| 	}, nil | 	}, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // Close closes the underlying database | ||||||
| func (a *Manager) Close() error { | func (a *Manager) Close() error { | ||||||
| 	return a.db.Close() | 	return a.db.Close() | ||||||
| } | } | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 binwiederhier
						binwiederhier