1
0
Fork 0
mirror of https://github.com/binwiederhier/ntfy.git synced 2024-11-22 03:13:33 +01:00

Delayed deletion

This commit is contained in:
binwiederhier 2023-01-22 22:21:30 -05:00
parent 9c082a8331
commit 954d919361
14 changed files with 280 additions and 131 deletions

View file

@ -16,8 +16,7 @@ import (
) )
const ( const (
tierReset = "-" tierReset = "-"
createdByCLI = "cli"
) )
func init() { func init() {
@ -197,7 +196,7 @@ func execUserAdd(c *cli.Context) error {
password = p password = p
} }
if err := manager.AddUser(username, password, role, createdByCLI); err != nil { if err := manager.AddUser(username, password, role); err != nil {
return err return err
} }
fmt.Fprintf(c.App.ErrWriter, "user %s added with role %s\n", username, role) fmt.Fprintf(c.App.ErrWriter, "user %s added with role %s\n", username, role)

View file

@ -39,12 +39,10 @@ TODO
-- --
- Reservation: Kill existing subscribers when topic is reserved (deadcade) - Reservation: Kill existing subscribers when topic is reserved (deadcade)
- Rate limiting: Sensitive endpoints (account/login/change-password/...)
- Stripe: Add metadata to customer
- Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben) - Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
- Reservation (UI): Ask for confirmation when removing reservation (deadcade) - Reservation (UI): Ask for confirmation when removing reservation (deadcade)
- Logging: Add detailed logging with username/customerID for all Stripe events (phil)
- Rate limiting: Sensitive endpoints (account/login/change-password/...)
- Stripe webhook: Do not respond wih error if user does not exist (after account deletion)
- Stripe: Add metadata to customer
races: races:
- v.user --> see publishSyncEventAsync() test - v.user --> see publishSyncEventAsync() test
@ -53,7 +51,7 @@ payments:
- reconciliation - reconciliation
delete messages + reserved topics on ResetTier delete attachments in access.go delete messages + reserved topics on ResetTier delete attachments in access.go
account deletion should delete messages and reservations and attachments
Limits & rate limiting: Limits & rate limiting:
rate limiting weirdness. wth is going on? rate limiting weirdness. wth is going on?
@ -1256,11 +1254,14 @@ func (s *Server) execManager() {
s.mu.Unlock() s.mu.Unlock()
log.Debug("Manager: Deleted %d stale visitor(s)", staleVisitors) log.Debug("Manager: Deleted %d stale visitor(s)", staleVisitors)
// Delete expired user tokens // Delete expired user tokens and users
if s.userManager != nil { if s.userManager != nil {
if err := s.userManager.RemoveExpiredTokens(); err != nil { if err := s.userManager.RemoveExpiredTokens(); err != nil {
log.Warn("Error expiring user tokens: %s", err.Error()) log.Warn("Error expiring user tokens: %s", err.Error())
} }
if err := s.userManager.RemoveDeletedUsers(); err != nil {
log.Warn("Error deleting soft-deleted users: %s", err.Error())
}
} }
// Delete expired attachments // Delete expired attachments
@ -1283,7 +1284,7 @@ func (s *Server) execManager() {
} }
} }
// DeleteMessages message cache // Prune messages
log.Debug("Manager: Pruning messages") log.Debug("Manager: Pruning messages")
expiredMessageIDs, err := s.messageCache.MessagesExpired() expiredMessageIDs, err := s.messageCache.MessagesExpired()
if err != nil { if err != nil {

View file

@ -11,7 +11,6 @@ import (
const ( const (
subscriptionIDLength = 16 subscriptionIDLength = 16
createdByAPI = "api"
syncTopicAccountSyncEvent = "sync" syncTopicAccountSyncEvent = "sync"
) )
@ -34,7 +33,7 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *
if v.accountLimiter != nil && !v.accountLimiter.Allow() { if v.accountLimiter != nil && !v.accountLimiter.Allow() {
return errHTTPTooManyRequestsLimitAccountCreation return errHTTPTooManyRequestsLimitAccountCreation
} }
if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser, createdByAPI); err != nil { // TODO this should return a User if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser); err != nil { // TODO this should return a User
return err return err
} }
return s.writeJSON(w, newSuccessResponse()) return s.writeJSON(w, newSuccessResponse())
@ -118,18 +117,20 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis
return s.writeJSON(w, response) return s.writeJSON(w, response)
} }
func (s *Server) handleAccountDelete(w http.ResponseWriter, _ *http.Request, v *visitor) error { func (s *Server) handleAccountDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
if v.user.Billing.StripeSubscriptionID != "" { if v.user.Billing.StripeSubscriptionID != "" {
log.Info("Deleting user %s (billing customer: %s, billing subscription: %s)", v.user.Name, v.user.Billing.StripeCustomerID, v.user.Billing.StripeSubscriptionID) log.Info("%s Canceling billing subscription %s", logHTTPPrefix(v, r), v.user.Billing.StripeSubscriptionID)
if v.user.Billing.StripeSubscriptionID != "" { if v.user.Billing.StripeSubscriptionID != "" {
if _, err := s.stripe.CancelSubscription(v.user.Billing.StripeSubscriptionID); err != nil { if _, err := s.stripe.CancelSubscription(v.user.Billing.StripeSubscriptionID); err != nil {
return err return err
} }
} }
} else { if err := s.maybeRemoveExcessReservations(logHTTPPrefix(v, r), v.user, 0); err != nil {
log.Info("Deleting user %s", v.user.Name) return err
}
} }
if err := s.userManager.RemoveUser(v.user.Name); err != nil { log.Info("%s Marking user %s as deleted", logHTTPPrefix(v, r), v.user.Name)
if err := s.userManager.MarkUserRemoved(v.user); err != nil {
return err return err
} }
return s.writeJSON(w, newSuccessResponse()) return s.writeJSON(w, newSuccessResponse())

View file

@ -67,8 +67,8 @@ func TestAccount_Signup_AsUser(t *testing.T) {
conf.EnableSignup = true conf.EnableSignup = true
s := newTestServer(t, conf) s := newTestServer(t, conf)
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, "unit-test")) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test")) require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
rr := request(t, s, "POST", "/v1/account", `{"username":"emma", "password":"emma"}`, map[string]string{ rr := request(t, s, "POST", "/v1/account", `{"username":"emma", "password":"emma"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"), "Authorization": util.BasicAuth("phil", "phil"),
@ -133,7 +133,7 @@ func TestAccount_Get_Anonymous(t *testing.T) {
func TestAccount_ChangeSettings(t *testing.T) { func TestAccount_ChangeSettings(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t)) s := newTestServer(t, newTestConfigWithAuthFile(t))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
user, _ := s.userManager.User("phil") user, _ := s.userManager.User("phil")
token, _ := s.userManager.CreateToken(user) token, _ := s.userManager.CreateToken(user)
@ -160,7 +160,7 @@ func TestAccount_ChangeSettings(t *testing.T) {
func TestAccount_Subscription_AddUpdateDelete(t *testing.T) { func TestAccount_Subscription_AddUpdateDelete(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t)) s := newTestServer(t, newTestConfigWithAuthFile(t))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
rr := request(t, s, "POST", "/v1/account/subscription", `{"base_url": "http://abc.com", "topic": "def"}`, map[string]string{ rr := request(t, s, "POST", "/v1/account/subscription", `{"base_url": "http://abc.com", "topic": "def"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"), "Authorization": util.BasicAuth("phil", "phil"),
@ -210,7 +210,7 @@ func TestAccount_Subscription_AddUpdateDelete(t *testing.T) {
func TestAccount_ChangePassword(t *testing.T) { func TestAccount_ChangePassword(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t)) s := newTestServer(t, newTestConfigWithAuthFile(t))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
rr := request(t, s, "POST", "/v1/account/password", `{"password": "phil", "new_password": "new password"}`, map[string]string{ rr := request(t, s, "POST", "/v1/account/password", `{"password": "phil", "new_password": "new password"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"), "Authorization": util.BasicAuth("phil", "phil"),
@ -237,7 +237,7 @@ func TestAccount_ChangePassword_NoAccount(t *testing.T) {
func TestAccount_ExtendToken(t *testing.T) { func TestAccount_ExtendToken(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t)) s := newTestServer(t, newTestConfigWithAuthFile(t))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
rr := request(t, s, "POST", "/v1/account/token", "", map[string]string{ rr := request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"), "Authorization": util.BasicAuth("phil", "phil"),
@ -260,7 +260,7 @@ func TestAccount_ExtendToken(t *testing.T) {
func TestAccount_ExtendToken_NoTokenProvided(t *testing.T) { func TestAccount_ExtendToken_NoTokenProvided(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t)) s := newTestServer(t, newTestConfigWithAuthFile(t))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
rr := request(t, s, "PATCH", "/v1/account/token", "", map[string]string{ rr := request(t, s, "PATCH", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"), // Not Bearer! "Authorization": util.BasicAuth("phil", "phil"), // Not Bearer!
@ -271,7 +271,7 @@ func TestAccount_ExtendToken_NoTokenProvided(t *testing.T) {
func TestAccount_DeleteToken(t *testing.T) { func TestAccount_DeleteToken(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t)) s := newTestServer(t, newTestConfigWithAuthFile(t))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
rr := request(t, s, "POST", "/v1/account/token", "", map[string]string{ rr := request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"), "Authorization": util.BasicAuth("phil", "phil"),
@ -324,10 +324,15 @@ func TestAccount_Delete_Success(t *testing.T) {
}) })
require.Equal(t, 200, rr.Code) require.Equal(t, 200, rr.Code)
// Account was marked deleted
rr = request(t, s, "GET", "/v1/account", "", map[string]string{ rr = request(t, s, "GET", "/v1/account", "", map[string]string{
"Authorization": util.BasicAuth("phil", "mypass"), "Authorization": util.BasicAuth("phil", "mypass"),
}) })
require.Equal(t, 401, rr.Code) require.Equal(t, 401, rr.Code)
// Cannot re-create account, since still exists
rr = request(t, s, "POST", "/v1/account", `{"username":"phil", "password":"mypass"}`, nil)
require.Equal(t, 409, rr.Code)
} }
func TestAccount_Delete_Not_Allowed(t *testing.T) { func TestAccount_Delete_Not_Allowed(t *testing.T) {
@ -360,7 +365,7 @@ func TestAccount_Reservation_AddAdminSuccess(t *testing.T) {
conf := newTestConfigWithAuthFile(t) conf := newTestConfigWithAuthFile(t)
conf.EnableSignup = true conf.EnableSignup = true
s := newTestServer(t, conf) s := newTestServer(t, conf)
require.Nil(t, s.userManager.AddUser("phil", "adminpass", user.RoleAdmin, "unit-test")) require.Nil(t, s.userManager.AddUser("phil", "adminpass", user.RoleAdmin))
rr := request(t, s, "POST", "/v1/account/reservation", `{"topic":"mytopic","everyone":"deny-all"}`, map[string]string{ rr := request(t, s, "POST", "/v1/account/reservation", `{"topic":"mytopic","everyone":"deny-all"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "adminpass"), "Authorization": util.BasicAuth("phil", "adminpass"),

View file

@ -18,15 +18,10 @@ import (
"io" "io"
"net/http" "net/http"
"net/netip" "net/netip"
"strings"
"time" "time"
) )
var (
errNotAPaidTier = errors.New("tier does not have billing price identifier")
errMultipleBillingSubscriptions = errors.New("cannot have multiple billing subscriptions")
errNoBillingSubscription = errors.New("user does not have an active billing subscription")
)
// Payments in ntfy are done via Stripe. // Payments in ntfy are done via Stripe.
// //
// Pretty much all payments related things are in this file. The following processes // Pretty much all payments related things are in this file. The following processes
@ -49,6 +44,16 @@ var (
// This is used to keep the local user database fields up to date. Stripe is the source of truth. // This is used to keep the local user database fields up to date. Stripe is the source of truth.
// What Stripe says is mirrored and not questioned. // What Stripe says is mirrored and not questioned.
var (
errNotAPaidTier = errors.New("tier does not have billing price identifier")
errMultipleBillingSubscriptions = errors.New("cannot have multiple billing subscriptions")
errNoBillingSubscription = errors.New("user does not have an active billing subscription")
)
var (
retryUserDelays = []time.Duration{3 * time.Second, 5 * time.Second, 7 * time.Second}
)
// handleBillingTiersGet returns all available paid tiers, and the free tier. This is to populate the upgrade dialog // handleBillingTiersGet returns all available paid tiers, and the free tier. This is to populate the upgrade dialog
// in the UI. Note that this endpoint does NOT have a user context (no v.user!). // in the UI. Note that this endpoint does NOT have a user context (no v.user!).
func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _ *visitor) error { func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
@ -114,7 +119,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
} else if tier.StripePriceID == "" { } else if tier.StripePriceID == "" {
return errNotAPaidTier return errNotAPaidTier
} }
log.Info("Stripe: No existing subscription, creating checkout flow") log.Info("%s Creating Stripe checkout flow", logHTTPPrefix(v, r))
var stripeCustomerID *string var stripeCustomerID *string
if v.user.Billing.StripeCustomerID != "" { if v.user.Billing.StripeCustomerID != "" {
stripeCustomerID = &v.user.Billing.StripeCustomerID stripeCustomerID = &v.user.Billing.StripeCustomerID
@ -138,9 +143,6 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
Quantity: stripe.Int64(1), Quantity: stripe.Int64(1),
}, },
}, },
/*AutomaticTax: &stripe.CheckoutSessionAutomaticTaxParams{
Enabled: stripe.Bool(true),
},*/
} }
sess, err := s.stripe.NewCheckoutSession(params) sess, err := s.stripe.NewCheckoutSession(params)
if err != nil { if err != nil {
@ -155,7 +157,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
// handleAccountBillingSubscriptionCreateSuccess is called after the Stripe checkout session has succeeded. We use // handleAccountBillingSubscriptionCreateSuccess is called after the Stripe checkout session has succeeded. We use
// the session ID in the URL to retrieve the Stripe subscription and update the local database. This is the first // the session ID in the URL to retrieve the Stripe subscription and update the local database. This is the first
// and only time we can map the local username with the Stripe customer ID. // and only time we can map the local username with the Stripe customer ID.
func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, _ *visitor) error { func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, v *visitor) error {
// We don't have a v.user in this endpoint, only a userManager! // We don't have a v.user in this endpoint, only a userManager!
matches := apiAccountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path) matches := apiAccountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path)
if len(matches) != 2 { if len(matches) != 2 {
@ -182,7 +184,8 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
if err != nil { if err != nil {
return err return err
} }
if err := s.updateSubscriptionAndTier(u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil { v.SetUser(u)
if err := s.updateSubscriptionAndTier(logHTTPPrefix(v, r), u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil {
return err return err
} }
http.Redirect(w, r, s.config.BaseURL+accountPath, http.StatusSeeOther) http.Redirect(w, r, s.config.BaseURL+accountPath, http.StatusSeeOther)
@ -203,7 +206,7 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
if err != nil { if err != nil {
return err return err
} }
log.Info("Stripe: Changing tier and subscription to %s", tier.Code) log.Info("%s Changing billing tier to %s (price %s) for subscription %s", logHTTPPrefix(v, r), tier.Code, tier.StripePriceID, v.user.Billing.StripeSubscriptionID)
sub, err := s.stripe.GetSubscription(v.user.Billing.StripeSubscriptionID) sub, err := s.stripe.GetSubscription(v.user.Billing.StripeSubscriptionID)
if err != nil { if err != nil {
return err return err
@ -228,6 +231,7 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
// handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user, // handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user,
// and cancelling the Stripe subscription entirely // and cancelling the Stripe subscription entirely
func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error { func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
log.Info("%s Deleting billing subscription %s", logHTTPPrefix(v, r), v.user.Billing.StripeSubscriptionID)
if v.user.Billing.StripeSubscriptionID != "" { if v.user.Billing.StripeSubscriptionID != "" {
params := &stripe.SubscriptionParams{ params := &stripe.SubscriptionParams{
CancelAtPeriodEnd: stripe.Bool(true), CancelAtPeriodEnd: stripe.Bool(true),
@ -246,6 +250,7 @@ func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter,
if v.user.Billing.StripeCustomerID == "" { if v.user.Billing.StripeCustomerID == "" {
return errHTTPBadRequestNotAPaidUser return errHTTPBadRequestNotAPaidUser
} }
log.Info("%s Creating billing portal session", logHTTPPrefix(v, r))
params := &stripe.BillingPortalSessionParams{ params := &stripe.BillingPortalSessionParams{
Customer: stripe.String(v.user.Billing.StripeCustomerID), Customer: stripe.String(v.user.Billing.StripeCustomerID),
ReturnURL: stripe.String(s.config.BaseURL), ReturnURL: stripe.String(s.config.BaseURL),
@ -280,28 +285,30 @@ func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Requ
} else if event.Data == nil || event.Data.Raw == nil { } else if event.Data == nil || event.Data.Raw == nil {
return errHTTPBadRequestBillingRequestInvalid return errHTTPBadRequestBillingRequestInvalid
} }
log.Info("Stripe: webhook event %s received", event.Type)
switch event.Type { switch event.Type {
case "customer.subscription.updated": case "customer.subscription.updated":
return s.handleAccountBillingWebhookSubscriptionUpdated(event.Data.Raw) return s.handleAccountBillingWebhookSubscriptionUpdated(event.Data.Raw)
case "customer.subscription.deleted": case "customer.subscription.deleted":
return s.handleAccountBillingWebhookSubscriptionDeleted(event.Data.Raw) return s.handleAccountBillingWebhookSubscriptionDeleted(event.Data.Raw)
default: default:
log.Warn("STRIPE Unhandled webhook event %s received", event.Type)
return nil return nil
} }
} }
func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMessage) error { func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMessage) error {
r, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event))) ev, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event)))
if err != nil { if err != nil {
return err return err
} else if r.ID == "" || r.Customer == "" || r.Status == "" || r.CurrentPeriodEnd == 0 || r.Items == nil || len(r.Items.Data) != 1 || r.Items.Data[0].Price == nil || r.Items.Data[0].Price.ID == "" { } else if ev.ID == "" || ev.Customer == "" || ev.Status == "" || ev.CurrentPeriodEnd == 0 || ev.Items == nil || len(ev.Items.Data) != 1 || ev.Items.Data[0].Price == nil || ev.Items.Data[0].Price.ID == "" {
return errHTTPBadRequestBillingRequestInvalid return errHTTPBadRequestBillingRequestInvalid
} }
subscriptionID, priceID := r.ID, r.Items.Data[0].Price.ID subscriptionID, priceID := ev.ID, ev.Items.Data[0].Price.ID
log.Info("Stripe: customer %s: Updating subscription to status %s, with price %s", r.Customer, r.Status, priceID) log.Info("%s Updating subscription to status %s, with price %s", logStripePrefix(ev.Customer, ev.ID), ev.Status, priceID)
u, err := s.userManager.UserByStripeCustomer(r.Customer) userFn := func() (*user.User, error) {
return s.userManager.UserByStripeCustomer(ev.Customer)
}
u, err := util.Retry[user.User](userFn, retryUserDelays...)
if err != nil { if err != nil {
return err return err
} }
@ -309,7 +316,7 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe
if err != nil { if err != nil {
return err return err
} }
if err := s.updateSubscriptionAndTier(u, tier, r.Customer, subscriptionID, r.Status, r.CurrentPeriodEnd, r.CancelAt); err != nil { if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, tier, ev.Customer, subscriptionID, ev.Status, ev.CurrentPeriodEnd, ev.CancelAt); err != nil {
return err return err
} }
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified())) s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
@ -317,47 +324,56 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe
} }
func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMessage) error { func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMessage) error {
r, err := util.UnmarshalJSON[apiStripeSubscriptionDeletedEvent](io.NopCloser(bytes.NewReader(event))) ev, err := util.UnmarshalJSON[apiStripeSubscriptionDeletedEvent](io.NopCloser(bytes.NewReader(event)))
if err != nil { if err != nil {
return err return err
} else if r.Customer == "" { } else if ev.Customer == "" {
return errHTTPBadRequestBillingRequestInvalid return errHTTPBadRequestBillingRequestInvalid
} }
log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", r.Customer) log.Info("%s Subscription deleted, downgrading to unpaid tier", logStripePrefix(ev.Customer, ev.ID))
u, err := s.userManager.UserByStripeCustomer(r.Customer) u, err := s.userManager.UserByStripeCustomer(ev.Customer)
if err != nil { if err != nil {
return err return err
} }
if err := s.updateSubscriptionAndTier(u, nil, r.Customer, "", "", 0, 0); err != nil { if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, nil, ev.Customer, "", "", 0, 0); err != nil {
return err return err
} }
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified())) s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
return nil return nil
} }
func (s *Server) updateSubscriptionAndTier(u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error { // maybeRemoveExcessReservations deletes topic reservations for the given user (if too many for tier),
// Remove excess reservations (if too many for tier), and mark associated messages deleted // and marks associated messages for the topics as deleted. This also eventually deletes attachments.
// The process relies on the manager to perform the actual deletions (see runManager).
func (s *Server) maybeRemoveExcessReservations(logPrefix string, u *user.User, reservationsLimit int64) error {
reservations, err := s.userManager.Reservations(u.Name) reservations, err := s.userManager.Reservations(u.Name)
if err != nil { if err != nil {
return err return err
} else if int64(len(reservations)) <= reservationsLimit {
return nil
} }
topics := make([]string, 0)
for i := int64(len(reservations)) - 1; i >= reservationsLimit; i-- {
topics = append(topics, reservations[i].Topic)
}
log.Info("%s Removing excess reservations for topics %s", logPrefix, strings.Join(topics, ", "))
if err := s.userManager.RemoveReservations(u.Name, topics...); err != nil {
return err
}
if err := s.messageCache.ExpireMessages(topics...); err != nil {
return err
}
return nil
}
func (s *Server) updateSubscriptionAndTier(logPrefix string, u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error {
reservationsLimit := visitorDefaultReservationsLimit reservationsLimit := visitorDefaultReservationsLimit
if tier != nil { if tier != nil {
reservationsLimit = tier.ReservationsLimit reservationsLimit = tier.ReservationsLimit
} }
if int64(len(reservations)) > reservationsLimit { if err := s.maybeRemoveExcessReservations(logPrefix, u, reservationsLimit); err != nil {
topics := make([]string, 0) return err
for i := int64(len(reservations)) - 1; i >= reservationsLimit; i-- {
topics = append(topics, reservations[i].Topic)
}
if err := s.userManager.RemoveReservations(u.Name, topics...); err != nil {
return err
}
if err := s.messageCache.ExpireMessages(topics...); err != nil {
return err
}
} }
// Change or remove tier
if tier == nil { if tier == nil {
if err := s.userManager.ResetTier(u.Name); err != nil { if err := s.userManager.ResetTier(u.Name); err != nil {
return err return err

View file

@ -34,7 +34,7 @@ func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
Code: "pro", Code: "pro",
StripePriceID: "price_123", StripePriceID: "price_123",
})) }))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
// Create subscription // Create subscription
response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{ response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
@ -69,7 +69,7 @@ func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
Code: "pro", Code: "pro",
StripePriceID: "price_123", StripePriceID: "price_123",
})) }))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
u, err := s.userManager.User("phil") u, err := s.userManager.User("phil")
require.Nil(t, err) require.Nil(t, err)
@ -110,7 +110,7 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
Code: "pro", Code: "pro",
StripePriceID: "price_123", StripePriceID: "price_123",
})) }))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
u, err := s.userManager.User("phil") u, err := s.userManager.User("phil")
require.Nil(t, err) require.Nil(t, err)
@ -174,7 +174,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
AttachmentFileSizeLimit: 1000000, AttachmentFileSizeLimit: 1000000,
AttachmentTotalSizeLimit: 1000000, AttachmentTotalSizeLimit: 1000000,
})) }))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("phil", "pro")) require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll)) require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll))
require.Nil(t, s.userManager.AddReservation("phil", "ztopic", user.PermissionDenyAll)) require.Nil(t, s.userManager.AddReservation("phil", "ztopic", user.PermissionDenyAll))

View file

@ -625,7 +625,7 @@ func TestServer_Auth_Success_Admin(t *testing.T) {
c := newTestConfigWithAuthFile(t) c := newTestConfigWithAuthFile(t)
s := newTestServer(t, c) s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, "unit-test")) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{ response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"), "Authorization": util.BasicAuth("phil", "phil"),
@ -639,7 +639,7 @@ func TestServer_Auth_Success_User(t *testing.T) {
c.AuthDefault = user.PermissionDenyAll c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c) s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test")) require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, s.userManager.AllowAccess("ben", "mytopic", user.PermissionReadWrite)) require.Nil(t, s.userManager.AllowAccess("ben", "mytopic", user.PermissionReadWrite))
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{ response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
@ -653,7 +653,7 @@ func TestServer_Auth_Success_User_MultipleTopics(t *testing.T) {
c.AuthDefault = user.PermissionDenyAll c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c) s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test")) require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, s.userManager.AllowAccess("ben", "mytopic", user.PermissionReadWrite)) require.Nil(t, s.userManager.AllowAccess("ben", "mytopic", user.PermissionReadWrite))
require.Nil(t, s.userManager.AllowAccess("ben", "anothertopic", user.PermissionReadWrite)) require.Nil(t, s.userManager.AllowAccess("ben", "anothertopic", user.PermissionReadWrite))
@ -674,7 +674,7 @@ func TestServer_Auth_Fail_InvalidPass(t *testing.T) {
c.AuthDefault = user.PermissionDenyAll c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c) s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, "unit-test")) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{ response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
"Authorization": util.BasicAuth("phil", "INVALID"), "Authorization": util.BasicAuth("phil", "INVALID"),
@ -687,7 +687,7 @@ func TestServer_Auth_Fail_Unauthorized(t *testing.T) {
c.AuthDefault = user.PermissionDenyAll c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c) s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test")) require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, s.userManager.AllowAccess("ben", "sometopic", user.PermissionReadWrite)) // Not mytopic! require.Nil(t, s.userManager.AllowAccess("ben", "sometopic", user.PermissionReadWrite)) // Not mytopic!
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{ response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
@ -701,7 +701,7 @@ func TestServer_Auth_Fail_CannotPublish(t *testing.T) {
c.AuthDefault = user.PermissionReadWrite // Open by default c.AuthDefault = user.PermissionReadWrite // Open by default
s := newTestServer(t, c) s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, "unit-test")) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
require.Nil(t, s.userManager.AllowAccess(user.Everyone, "private", user.PermissionDenyAll)) require.Nil(t, s.userManager.AllowAccess(user.Everyone, "private", user.PermissionDenyAll))
require.Nil(t, s.userManager.AllowAccess(user.Everyone, "announcements", user.PermissionRead)) require.Nil(t, s.userManager.AllowAccess(user.Everyone, "announcements", user.PermissionRead))
@ -731,7 +731,7 @@ func TestServer_Auth_ViaQuery(t *testing.T) {
c.AuthDefault = user.PermissionDenyAll c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c) s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("ben", "some pass", user.RoleAdmin, "unit-test")) require.Nil(t, s.userManager.AddUser("ben", "some pass", user.RoleAdmin))
u := fmt.Sprintf("/mytopic/json?poll=1&auth=%s", base64.RawURLEncoding.EncodeToString([]byte(util.BasicAuth("ben", "some pass")))) u := fmt.Sprintf("/mytopic/json?poll=1&auth=%s", base64.RawURLEncoding.EncodeToString([]byte(util.BasicAuth("ben", "some pass"))))
response := request(t, s, "GET", u, "", nil) response := request(t, s, "GET", u, "", nil)
@ -749,7 +749,7 @@ func TestServer_StatsResetter(t *testing.T) {
s := newTestServer(t, c) s := newTestServer(t, c)
go s.runStatsResetter() go s.runStatsResetter()
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.AllowAccess("phil", "mytopic", user.PermissionReadWrite)) require.Nil(t, s.userManager.AllowAccess("phil", "mytopic", user.PermissionReadWrite))
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
@ -1137,7 +1137,7 @@ func TestServer_PublishWithTierBasedMessageLimitAndExpiry(t *testing.T) {
MessagesLimit: 5, MessagesLimit: 5,
MessagesExpiryDuration: -5 * time.Second, // Second, what a hack! MessagesExpiryDuration: -5 * time.Second, // Second, what a hack!
})) }))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("phil", "test")) require.Nil(t, s.userManager.ChangeTier("phil", "test"))
// Publish to reach message limit // Publish to reach message limit
@ -1369,7 +1369,7 @@ func TestServer_PublishAttachmentWithTierBasedExpiry(t *testing.T) {
AttachmentTotalSizeLimit: 200_000, AttachmentTotalSizeLimit: 200_000,
AttachmentExpiryDuration: sevenDays, // 7 days AttachmentExpiryDuration: sevenDays, // 7 days
})) }))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("phil", "test")) require.Nil(t, s.userManager.ChangeTier("phil", "test"))
// Publish and make sure we can retrieve it // Publish and make sure we can retrieve it
@ -1413,7 +1413,7 @@ func TestServer_PublishAttachmentWithTierBasedLimits(t *testing.T) {
AttachmentTotalSizeLimit: 200_000, AttachmentTotalSizeLimit: 200_000,
AttachmentExpiryDuration: 30 * time.Second, AttachmentExpiryDuration: 30 * time.Second,
})) }))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("phil", "test")) require.Nil(t, s.userManager.ChangeTier("phil", "test"))
// Publish small file as anonymous // Publish small file as anonymous

View file

@ -354,5 +354,6 @@ type apiStripeSubscriptionUpdatedEvent struct {
} }
type apiStripeSubscriptionDeletedEvent struct { type apiStripeSubscriptionDeletedEvent struct {
ID string `json:"id"`
Customer string `json:"customer"` Customer string `json:"customer"`
} }

View file

@ -49,7 +49,7 @@ func readQueryParam(r *http.Request, names ...string) string {
} }
func logMessagePrefix(v *visitor, m *message) string { func logMessagePrefix(v *visitor, m *message) string {
return fmt.Sprintf("%s/%s/%s", v.ip, m.Topic, m.ID) return fmt.Sprintf("%s/%s/%s", v.String(), m.Topic, m.ID)
} }
func logHTTPPrefix(v *visitor, r *http.Request) string { func logHTTPPrefix(v *visitor, r *http.Request) string {
@ -57,7 +57,14 @@ func logHTTPPrefix(v *visitor, r *http.Request) string {
if requestURI == "" { if requestURI == "" {
requestURI = r.URL.Path requestURI = r.URL.Path
} }
return fmt.Sprintf("%s HTTP %s %s", v.ip, r.Method, requestURI) return fmt.Sprintf("%s HTTP %s %s", v.String(), r.Method, requestURI)
}
func logStripePrefix(customerID, subscriptionID string) string {
if subscriptionID != "" {
return fmt.Sprintf("%s/%s STRIPE", customerID, subscriptionID)
}
return fmt.Sprintf("%s STRIPE", customerID)
} }
func logSMTPPrefix(state *smtp.ConnectionState) string { func logSMTPPrefix(state *smtp.ConnectionState) string {

View file

@ -2,6 +2,7 @@ package server
import ( import (
"errors" "errors"
"fmt"
"heckel.io/ntfy/user" "heckel.io/ntfy/user"
"net/netip" "net/netip"
"sync" "sync"
@ -119,6 +120,17 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana
} }
} }
func (v *visitor) String() string {
v.mu.Lock()
defer v.mu.Unlock()
if v.user != nil && v.user.Billing.StripeCustomerID != "" {
return fmt.Sprintf("%s/%s/%s", v.ip.String(), v.user.ID, v.user.Billing.StripeCustomerID)
} else if v.user != nil {
return fmt.Sprintf("%s/%s", v.ip.String(), v.user.ID)
}
return v.ip.String()
}
func (v *visitor) RequestAllowed() error { func (v *visitor) RequestAllowed() error {
if !v.requestLimiter.Allow() { if !v.requestLimiter.Allow() {
return errVisitorLimitReached return errVisitorLimitReached
@ -216,6 +228,12 @@ func (v *visitor) ResetStats() {
} }
} }
func (v *visitor) SetUser(u *user.User) {
v.mu.Lock()
defer v.mu.Unlock()
v.user = u
}
func (v *visitor) Limits() *visitorLimits { func (v *visitor) Limits() *visitorLimits {
v.mu.Lock() v.mu.Lock()
defer v.mu.Unlock() defer v.mu.Unlock()

View file

@ -25,6 +25,7 @@ const (
userPasswordBcryptCost = 10 userPasswordBcryptCost = 10
userAuthIntentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match userPasswordBcryptCost userAuthIntentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match userPasswordBcryptCost
userStatsQueueWriterInterval = 33 * time.Second userStatsQueueWriterInterval = 33 * time.Second
userHardDeleteAfterDuration = 7 * 24 * time.Hour
tokenPrefix = "tk_" tokenPrefix = "tk_"
tokenLength = 32 tokenLength = 32
tokenMaxCount = 10 // Only keep this many tokens in the table per user tokenMaxCount = 10 // Only keep this many tokens in the table per user
@ -57,7 +58,7 @@ const (
CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id); CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
CREATE TABLE IF NOT EXISTS user ( CREATE TABLE IF NOT EXISTS user (
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY,
tier_id INT, tier_id TEXT,
user TEXT NOT NULL, user TEXT NOT NULL,
pass TEXT NOT NULL, pass TEXT NOT NULL,
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL, role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
@ -70,8 +71,8 @@ const (
stripe_subscription_status TEXT, stripe_subscription_status TEXT,
stripe_subscription_paid_until INT, stripe_subscription_paid_until INT,
stripe_subscription_cancel_at INT, stripe_subscription_cancel_at INT,
created_by TEXT NOT NULL, created INT NOT NULL,
created_at INT NOT NULL, deleted INT,
FOREIGN KEY (tier_id) REFERENCES tier (id) FOREIGN KEY (tier_id) REFERENCES tier (id)
); );
CREATE UNIQUE INDEX idx_user ON user (user); CREATE UNIQUE INDEX idx_user ON user (user);
@ -98,8 +99,8 @@ const (
id INT PRIMARY KEY, id INT PRIMARY KEY,
version INT NOT NULL version INT NOT NULL
); );
INSERT INTO user (id, user, pass, role, sync_topic, created_by, created_at) INSERT INTO user (id, user, pass, role, sync_topic, created)
VALUES ('u_everyone', '*', '', 'anonymous', '', 'system', UNIXEPOCH()) VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', UNIXEPOCH())
ON CONFLICT (id) DO NOTHING; ON CONFLICT (id) DO NOTHING;
` `
createTablesQueries = `BEGIN; ` + createTablesQueriesNoTx + ` COMMIT;` createTablesQueries = `BEGIN; ` + createTablesQueriesNoTx + ` COMMIT;`
@ -108,26 +109,26 @@ const (
` `
selectUserByIDQuery = ` selectUserByIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
FROM user u FROM user u
LEFT JOIN tier t on t.id = u.tier_id LEFT JOIN tier t on t.id = u.tier_id
WHERE u.id = ? WHERE u.id = ?
` `
selectUserByNameQuery = ` selectUserByNameQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
FROM user u FROM user u
LEFT JOIN tier t on t.id = u.tier_id LEFT JOIN tier t on t.id = u.tier_id
WHERE user = ? WHERE user = ?
` `
selectUserByTokenQuery = ` selectUserByTokenQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
FROM user u FROM user u
JOIN user_token t on u.id = t.user_id JOIN user_token t on u.id = t.user_id
LEFT JOIN tier t on t.id = u.tier_id LEFT JOIN tier t on t.id = u.tier_id
WHERE t.token = ? AND t.expires >= ? WHERE t.token = ? AND t.expires >= ?
` `
selectUserByStripeCustomerIDQuery = ` selectUserByStripeCustomerIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
FROM user u FROM user u
LEFT JOIN tier t on t.id = u.tier_id LEFT JOIN tier t on t.id = u.tier_id
WHERE u.stripe_customer_id = ? WHERE u.stripe_customer_id = ?
@ -141,8 +142,8 @@ const (
` `
insertUserQuery = ` insertUserQuery = `
INSERT INTO user (id, user, pass, role, sync_topic, created_by, created_at) INSERT INTO user (id, user, pass, role, sync_topic, created)
VALUES (?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?)
` `
selectUsernamesQuery = ` selectUsernamesQuery = `
SELECT user SELECT user
@ -159,6 +160,8 @@ const (
updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE user = ?` updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE user = ?`
updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ? WHERE id = ?` updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ? WHERE id = ?`
updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0` updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0`
updateUserDeletedQuery = `UPDATE user SET deleted = ? WHERE id = ?`
deleteUsersMarkedQuery = `DELETE FROM user WHERE deleted < ?`
deleteUserQuery = `DELETE FROM user WHERE user = ?` deleteUserQuery = `DELETE FROM user WHERE user = ?`
upsertUserAccessQuery = ` upsertUserAccessQuery = `
@ -214,7 +217,8 @@ const (
selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?` selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
insertTokenQuery = `INSERT INTO user_token (user_id, token, expires) VALUES (?, ?, ?)` insertTokenQuery = `INSERT INTO user_token (user_id, token, expires) VALUES (?, ?, ?)`
updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?` updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?`
deleteTokenQuery = `DELETE FROM user_token WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?` deleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
deleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?`
deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires < ?` deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires < ?`
deleteExcessTokensQuery = ` deleteExcessTokensQuery = `
DELETE FROM user_token DELETE FROM user_token
@ -268,8 +272,8 @@ const (
` `
migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old` migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old`
migrate1To2InsertUserNoTx = ` migrate1To2InsertUserNoTx = `
INSERT INTO user (id, user, pass, role, sync_topic, created_by, created_at) INSERT INTO user (id, user, pass, role, sync_topic, created)
SELECT ?, user, pass, role, ?, 'admin', UNIXEPOCH() FROM user_old WHERE user = ? SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ?
` `
migrate1To2InsertFromOldTablesAndDropNoTx = ` migrate1To2InsertFromOldTablesAndDropNoTx = `
INSERT INTO user_access (user_id, topic, read, write) INSERT INTO user_access (user_id, topic, read, write)
@ -320,9 +324,9 @@ func newManager(filename, startupQueries string, defaultAccess Permission, stats
return manager, nil return manager, nil
} }
// Authenticate checks username and password and returns a User if correct. The method // Authenticate checks username and password and returns a User if correct, and the user has not been
// returns in constant-ish time, regardless of whether the user exists or the password is // marked as deleted. The method returns in constant-ish time, regardless of whether the user exists or
// correct or incorrect. // the password is correct or incorrect.
func (a *Manager) Authenticate(username, password string) (*User, error) { func (a *Manager) Authenticate(username, password string) (*User, error) {
if username == Everyone { if username == Everyone {
return nil, ErrUnauthenticated return nil, ErrUnauthenticated
@ -332,9 +336,12 @@ func (a *Manager) Authenticate(username, password string) (*User, error) {
log.Trace("authentication of user %s failed (1): %s", username, err.Error()) log.Trace("authentication of user %s failed (1): %s", username, err.Error())
bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks")) bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
return nil, ErrUnauthenticated return nil, ErrUnauthenticated
} } else if user.Deleted {
if err := bcrypt.CompareHashAndPassword([]byte(user.Hash), []byte(password)); err != nil { log.Trace("authentication of user %s failed (2): user marked deleted", username)
log.Trace("authentication of user %s failed (2): %s", username, err.Error()) bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
return nil, ErrUnauthenticated
} else if err := bcrypt.CompareHashAndPassword([]byte(user.Hash), []byte(password)); err != nil {
log.Trace("authentication of user %s failed (3): %s", username, err.Error())
return nil, ErrUnauthenticated return nil, ErrUnauthenticated
} }
return user, nil return user, nil
@ -415,7 +422,7 @@ func (a *Manager) RemoveToken(user *User) error {
if user.Token == "" { if user.Token == "" {
return ErrUnauthorized return ErrUnauthorized
} }
if _, err := a.db.Exec(deleteTokenQuery, user.Name, user.Token); err != nil { if _, err := a.db.Exec(deleteTokenQuery, user.ID, user.Token); err != nil {
return err return err
} }
return nil return nil
@ -429,6 +436,14 @@ func (a *Manager) RemoveExpiredTokens() error {
return nil return nil
} }
// RemoveDeletedUsers deletes all users that have been marked deleted for
func (a *Manager) RemoveDeletedUsers() error {
if _, err := a.db.Exec(deleteUsersMarkedQuery, time.Now().Unix()); err != nil {
return err
}
return nil
}
// ChangeSettings persists the user settings // ChangeSettings persists the user settings
func (a *Manager) ChangeSettings(user *User) error { func (a *Manager) ChangeSettings(user *User) error {
prefs, err := json.Marshal(user.Prefs) prefs, err := json.Marshal(user.Prefs)
@ -533,7 +548,7 @@ func (a *Manager) resolvePerms(base, perm Permission) error {
} }
// AddUser adds a user with the given username, password and role // AddUser adds a user with the given username, password and role
func (a *Manager) AddUser(username, password string, role Role, createdBy string) error { func (a *Manager) AddUser(username, password string, role Role) error {
if !AllowedUsername(username) || !AllowedRole(role) { if !AllowedUsername(username) || !AllowedRole(role) {
return ErrInvalidArgument return ErrInvalidArgument
} }
@ -543,7 +558,7 @@ func (a *Manager) AddUser(username, password string, role Role, createdBy string
} }
userID := util.RandomStringPrefix(userIDPrefix, userIDLength) userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
syncTopic, now := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength), time.Now().Unix() syncTopic, now := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength), time.Now().Unix()
if _, err = a.db.Exec(insertUserQuery, userID, username, hash, role, syncTopic, createdBy, now); err != nil { if _, err = a.db.Exec(insertUserQuery, userID, username, hash, role, syncTopic, now); err != nil {
return err return err
} }
return nil return nil
@ -562,6 +577,29 @@ func (a *Manager) RemoveUser(username string) error {
return nil return nil
} }
// MarkUserRemoved sets the deleted flag on the user, and deletes all access tokens. This prevents
// successful auth via Authenticate. A background process will delete the user at a later date.
func (a *Manager) MarkUserRemoved(user *User) error {
if !AllowedUsername(user.Name) {
return ErrInvalidArgument
}
tx, err := a.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := a.db.Exec(deleteUserAccessQuery, user.Name, user.Name); err != nil {
return err
}
if _, err := tx.Exec(deleteAllTokenQuery, user.ID); err != nil {
return err
}
if _, err := tx.Exec(updateUserDeletedQuery, time.Now().Add(userHardDeleteAfterDuration).Unix(), user.ID); err != nil {
return err
}
return tx.Commit()
}
// Users returns a list of users. It always also returns the Everyone user ("*"). // Users returns a list of users. It always also returns the Everyone user ("*").
func (a *Manager) Users() ([]*User, error) { func (a *Manager) Users() ([]*User, error) {
rows, err := a.db.Query(selectUsernamesQuery) rows, err := a.db.Query(selectUsernamesQuery)
@ -632,11 +670,11 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
var id, username, hash, role, prefs, syncTopic string var id, username, hash, role, prefs, syncTopic string
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString
var messages, emails int64 var messages, emails int64
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt sql.NullInt64 var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
if !rows.Next() { if !rows.Next() {
return nil, ErrUserNotFound return nil, ErrUserNotFound
} }
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil { if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
return nil, err return nil, err
} else if err := rows.Err(); err != nil { } else if err := rows.Err(); err != nil {
return nil, err return nil, err
@ -659,6 +697,7 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), // May be zero StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), // May be zero
StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0), // May be zero StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0), // May be zero
}, },
Deleted: deleted.Valid,
} }
if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil { if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil {
return nil, err return nil, err

View file

@ -13,8 +13,8 @@ const minBcryptTimingMillis = int64(50) // Ideally should be >100ms, but this sh
func TestManager_FullScenario_Default_DenyAll(t *testing.T) { func TestManager_FullScenario_Default_DenyAll(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, "unit-test")) require.Nil(t, a.AddUser("phil", "phil", RoleAdmin))
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) require.Nil(t, a.AddUser("ben", "ben", RoleUser))
require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite)) require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite))
require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead)) require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead))
require.Nil(t, a.AllowAccess("ben", "writeme", PermissionWrite)) require.Nil(t, a.AllowAccess("ben", "writeme", PermissionWrite))
@ -92,20 +92,20 @@ func TestManager_FullScenario_Default_DenyAll(t *testing.T) {
func TestManager_AddUser_Invalid(t *testing.T) { func TestManager_AddUser_Invalid(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) a := newTestManager(t, PermissionDenyAll)
require.Equal(t, ErrInvalidArgument, a.AddUser(" invalid ", "pass", RoleAdmin, "unit-test")) require.Equal(t, ErrInvalidArgument, a.AddUser(" invalid ", "pass", RoleAdmin))
require.Equal(t, ErrInvalidArgument, a.AddUser("validuser", "pass", "invalid-role", "unit-test")) require.Equal(t, ErrInvalidArgument, a.AddUser("validuser", "pass", "invalid-role"))
} }
func TestManager_AddUser_Timing(t *testing.T) { func TestManager_AddUser_Timing(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) a := newTestManager(t, PermissionDenyAll)
start := time.Now().UnixMilli() start := time.Now().UnixMilli()
require.Nil(t, a.AddUser("user", "pass", RoleAdmin, "unit-test")) require.Nil(t, a.AddUser("user", "pass", RoleAdmin))
require.GreaterOrEqual(t, time.Now().UnixMilli()-start, minBcryptTimingMillis) require.GreaterOrEqual(t, time.Now().UnixMilli()-start, minBcryptTimingMillis)
} }
func TestManager_Authenticate_Timing(t *testing.T) { func TestManager_Authenticate_Timing(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.AddUser("user", "pass", RoleAdmin, "unit-test")) require.Nil(t, a.AddUser("user", "pass", RoleAdmin))
// Timing a correct attempt // Timing a correct attempt
start := time.Now().UnixMilli() start := time.Now().UnixMilli()
@ -126,10 +126,60 @@ func TestManager_Authenticate_Timing(t *testing.T) {
require.GreaterOrEqual(t, time.Now().UnixMilli()-start, minBcryptTimingMillis) require.GreaterOrEqual(t, time.Now().UnixMilli()-start, minBcryptTimingMillis)
} }
func TestManager_MarkUserRemoved_RemoveDeletedUsers(t *testing.T) {
a := newTestManager(t, PermissionDenyAll)
// Create user, add reservations and token
require.Nil(t, a.AddUser("user", "pass", RoleAdmin))
require.Nil(t, a.AddReservation("user", "mytopic", PermissionRead))
u, err := a.User("user")
require.Nil(t, err)
require.False(t, u.Deleted)
token, err := a.CreateToken(u)
require.Nil(t, err)
u, err = a.Authenticate("user", "pass")
require.Nil(t, err)
_, err = a.AuthenticateToken(token.Value)
require.Nil(t, err)
reservations, err := a.Reservations("user")
require.Nil(t, err)
require.Equal(t, 1, len(reservations))
// Mark deleted: cannot auth anymore, and all reservations are gone
require.Nil(t, a.MarkUserRemoved(u))
_, err = a.Authenticate("user", "pass")
require.Equal(t, ErrUnauthenticated, err)
_, err = a.AuthenticateToken(token.Value)
require.Equal(t, ErrUnauthenticated, err)
reservations, err = a.Reservations("user")
require.Nil(t, err)
require.Equal(t, 0, len(reservations))
// Make sure user is still there
u, err = a.User("user")
require.Nil(t, err)
require.True(t, u.Deleted)
_, err = a.db.Exec("UPDATE user SET deleted = ? WHERE id = ?", time.Now().Add(-1*(userHardDeleteAfterDuration+time.Hour)).Unix(), u.ID)
require.Nil(t, err)
require.Nil(t, a.RemoveDeletedUsers())
_, err = a.User("user")
require.Equal(t, ErrUserNotFound, err)
}
func TestManager_UserManagement(t *testing.T) { func TestManager_UserManagement(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, "unit-test")) require.Nil(t, a.AddUser("phil", "phil", RoleAdmin))
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) require.Nil(t, a.AddUser("ben", "ben", RoleUser))
require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite)) require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite))
require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead)) require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead))
require.Nil(t, a.AllowAccess("ben", "writeme", PermissionWrite)) require.Nil(t, a.AllowAccess("ben", "writeme", PermissionWrite))
@ -219,7 +269,7 @@ func TestManager_UserManagement(t *testing.T) {
func TestManager_ChangePassword(t *testing.T) { func TestManager_ChangePassword(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, "unit-test")) require.Nil(t, a.AddUser("phil", "phil", RoleAdmin))
_, err := a.Authenticate("phil", "phil") _, err := a.Authenticate("phil", "phil")
require.Nil(t, err) require.Nil(t, err)
@ -233,7 +283,7 @@ func TestManager_ChangePassword(t *testing.T) {
func TestManager_ChangeRole(t *testing.T) { func TestManager_ChangeRole(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) require.Nil(t, a.AddUser("ben", "ben", RoleUser))
require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite)) require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite))
require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead)) require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead))
@ -258,7 +308,7 @@ func TestManager_ChangeRole(t *testing.T) {
func TestManager_Reservations(t *testing.T) { func TestManager_Reservations(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) require.Nil(t, a.AddUser("ben", "ben", RoleUser))
require.Nil(t, a.AddReservation("ben", "ztopic", PermissionDenyAll)) require.Nil(t, a.AddReservation("ben", "ztopic", PermissionDenyAll))
require.Nil(t, a.AddReservation("ben", "readme", PermissionRead)) require.Nil(t, a.AddReservation("ben", "readme", PermissionRead))
require.Nil(t, a.AllowAccess("ben", "something-else", PermissionRead)) require.Nil(t, a.AllowAccess("ben", "something-else", PermissionRead))
@ -292,7 +342,7 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
AttachmentTotalSizeLimit: 524288000, AttachmentTotalSizeLimit: 524288000,
AttachmentExpiryDuration: 24 * time.Hour, AttachmentExpiryDuration: 24 * time.Hour,
})) }))
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) require.Nil(t, a.AddUser("ben", "ben", RoleUser))
require.Nil(t, a.ChangeTier("ben", "pro")) require.Nil(t, a.ChangeTier("ben", "pro"))
require.Nil(t, a.AddReservation("ben", "mytopic", PermissionDenyAll)) require.Nil(t, a.AddReservation("ben", "mytopic", PermissionDenyAll))
@ -340,7 +390,7 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
func TestManager_Token_Valid(t *testing.T) { func TestManager_Token_Valid(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) require.Nil(t, a.AddUser("ben", "ben", RoleUser))
u, err := a.User("ben") u, err := a.User("ben")
require.Nil(t, err) require.Nil(t, err)
@ -365,7 +415,7 @@ func TestManager_Token_Valid(t *testing.T) {
func TestManager_Token_Invalid(t *testing.T) { func TestManager_Token_Invalid(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) require.Nil(t, a.AddUser("ben", "ben", RoleUser))
u, err := a.AuthenticateToken(strings.Repeat("x", 32)) // 32 == token length u, err := a.AuthenticateToken(strings.Repeat("x", 32)) // 32 == token length
require.Nil(t, u) require.Nil(t, u)
@ -378,7 +428,7 @@ func TestManager_Token_Invalid(t *testing.T) {
func TestManager_Token_Expire(t *testing.T) { func TestManager_Token_Expire(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) require.Nil(t, a.AddUser("ben", "ben", RoleUser))
u, err := a.User("ben") u, err := a.User("ben")
require.Nil(t, err) require.Nil(t, err)
@ -426,7 +476,7 @@ func TestManager_Token_Expire(t *testing.T) {
func TestManager_Token_Extend(t *testing.T) { func TestManager_Token_Extend(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) require.Nil(t, a.AddUser("ben", "ben", RoleUser))
// Try to extend token for user without token // Try to extend token for user without token
u, err := a.User("ben") u, err := a.User("ben")
@ -453,7 +503,7 @@ func TestManager_Token_Extend(t *testing.T) {
func TestManager_Token_MaxCount_AutoDelete(t *testing.T) { func TestManager_Token_MaxCount_AutoDelete(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) require.Nil(t, a.AddUser("ben", "ben", RoleUser))
// Try to extend token for user without token // Try to extend token for user without token
u, err := a.User("ben") u, err := a.User("ben")
@ -497,7 +547,7 @@ func TestManager_Token_MaxCount_AutoDelete(t *testing.T) {
func TestManager_EnqueueStats(t *testing.T) { func TestManager_EnqueueStats(t *testing.T) {
a, err := newManager(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, 1500*time.Millisecond) a, err := newManager(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, 1500*time.Millisecond)
require.Nil(t, err) require.Nil(t, err)
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) require.Nil(t, a.AddUser("ben", "ben", RoleUser))
// Baseline: No messages or emails // Baseline: No messages or emails
u, err := a.User("ben") u, err := a.User("ben")
@ -527,7 +577,7 @@ func TestManager_EnqueueStats(t *testing.T) {
func TestManager_ChangeSettings(t *testing.T) { func TestManager_ChangeSettings(t *testing.T) {
a, err := newManager(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, 1500*time.Millisecond) a, err := newManager(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, 1500*time.Millisecond)
require.Nil(t, err) require.Nil(t, err)
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) require.Nil(t, a.AddUser("ben", "ben", RoleUser))
// No settings // No settings
u, err := a.User("ben") u, err := a.User("ben")

View file

@ -20,8 +20,7 @@ type User struct {
Stats *Stats Stats *Stats
Billing *Billing Billing *Billing
SyncTopic string SyncTopic string
Created time.Time Deleted bool
LastSeen time.Time
} }
// Auther is an interface for authentication and authorization // Auther is an interface for authentication and authorization
@ -186,7 +185,8 @@ const (
// Everyone is a special username representing anonymous users // Everyone is a special username representing anonymous users
const ( const (
Everyone = "*" Everyone = "*"
everyoneID = "u_everyone"
) )
var ( var (

View file

@ -324,3 +324,15 @@ func UnmarshalJSONWithLimit[T any](r io.ReadCloser, limit int) (*T, error) {
} }
return &obj, nil return &obj, nil
} }
// Retry executes function f until if succeeds, and then returns t. If f fails, it sleeps
// and tries again. The sleep durations are passed as the after params.
func Retry[T any](f func() (*T, error), after ...time.Duration) (t *T, err error) {
for _, delay := range after {
if t, err = f(); err == nil {
return t, nil
}
time.Sleep(delay)
}
return nil, err
}