Delayed deletion

This commit is contained in:
binwiederhier 2023-01-22 22:21:30 -05:00
parent 9c082a8331
commit 954d919361
14 changed files with 280 additions and 131 deletions

View File

@ -16,8 +16,7 @@ import (
)
const (
tierReset = "-"
createdByCLI = "cli"
tierReset = "-"
)
func init() {
@ -197,7 +196,7 @@ func execUserAdd(c *cli.Context) error {
password = p
}
if err := manager.AddUser(username, password, role, createdByCLI); err != nil {
if err := manager.AddUser(username, password, role); err != nil {
return err
}
fmt.Fprintf(c.App.ErrWriter, "user %s added with role %s\n", username, role)

View File

@ -39,12 +39,10 @@ TODO
--
- Reservation: Kill existing subscribers when topic is reserved (deadcade)
- Rate limiting: Sensitive endpoints (account/login/change-password/...)
- Stripe: Add metadata to customer
- Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
- Reservation (UI): Ask for confirmation when removing reservation (deadcade)
- Logging: Add detailed logging with username/customerID for all Stripe events (phil)
- Rate limiting: Sensitive endpoints (account/login/change-password/...)
- Stripe webhook: Do not respond wih error if user does not exist (after account deletion)
- Stripe: Add metadata to customer
races:
- v.user --> see publishSyncEventAsync() test
@ -53,7 +51,7 @@ payments:
- reconciliation
delete messages + reserved topics on ResetTier delete attachments in access.go
account deletion should delete messages and reservations and attachments
Limits & rate limiting:
rate limiting weirdness. wth is going on?
@ -1256,11 +1254,14 @@ func (s *Server) execManager() {
s.mu.Unlock()
log.Debug("Manager: Deleted %d stale visitor(s)", staleVisitors)
// Delete expired user tokens
// Delete expired user tokens and users
if s.userManager != nil {
if err := s.userManager.RemoveExpiredTokens(); err != nil {
log.Warn("Error expiring user tokens: %s", err.Error())
}
if err := s.userManager.RemoveDeletedUsers(); err != nil {
log.Warn("Error deleting soft-deleted users: %s", err.Error())
}
}
// Delete expired attachments
@ -1283,7 +1284,7 @@ func (s *Server) execManager() {
}
}
// DeleteMessages message cache
// Prune messages
log.Debug("Manager: Pruning messages")
expiredMessageIDs, err := s.messageCache.MessagesExpired()
if err != nil {

View File

@ -11,7 +11,6 @@ import (
const (
subscriptionIDLength = 16
createdByAPI = "api"
syncTopicAccountSyncEvent = "sync"
)
@ -34,7 +33,7 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *
if v.accountLimiter != nil && !v.accountLimiter.Allow() {
return errHTTPTooManyRequestsLimitAccountCreation
}
if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser, createdByAPI); err != nil { // TODO this should return a User
if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser); err != nil { // TODO this should return a User
return err
}
return s.writeJSON(w, newSuccessResponse())
@ -118,18 +117,20 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis
return s.writeJSON(w, response)
}
func (s *Server) handleAccountDelete(w http.ResponseWriter, _ *http.Request, v *visitor) error {
func (s *Server) handleAccountDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
if v.user.Billing.StripeSubscriptionID != "" {
log.Info("Deleting user %s (billing customer: %s, billing subscription: %s)", v.user.Name, v.user.Billing.StripeCustomerID, v.user.Billing.StripeSubscriptionID)
log.Info("%s Canceling billing subscription %s", logHTTPPrefix(v, r), v.user.Billing.StripeSubscriptionID)
if v.user.Billing.StripeSubscriptionID != "" {
if _, err := s.stripe.CancelSubscription(v.user.Billing.StripeSubscriptionID); err != nil {
return err
}
}
} else {
log.Info("Deleting user %s", v.user.Name)
if err := s.maybeRemoveExcessReservations(logHTTPPrefix(v, r), v.user, 0); err != nil {
return err
}
}
if err := s.userManager.RemoveUser(v.user.Name); err != nil {
log.Info("%s Marking user %s as deleted", logHTTPPrefix(v, r), v.user.Name)
if err := s.userManager.MarkUserRemoved(v.user); err != nil {
return err
}
return s.writeJSON(w, newSuccessResponse())

View File

@ -67,8 +67,8 @@ func TestAccount_Signup_AsUser(t *testing.T) {
conf.EnableSignup = true
s := newTestServer(t, conf)
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, "unit-test"))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
rr := request(t, s, "POST", "/v1/account", `{"username":"emma", "password":"emma"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
@ -133,7 +133,7 @@ func TestAccount_Get_Anonymous(t *testing.T) {
func TestAccount_ChangeSettings(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
user, _ := s.userManager.User("phil")
token, _ := s.userManager.CreateToken(user)
@ -160,7 +160,7 @@ func TestAccount_ChangeSettings(t *testing.T) {
func TestAccount_Subscription_AddUpdateDelete(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
rr := request(t, s, "POST", "/v1/account/subscription", `{"base_url": "http://abc.com", "topic": "def"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
@ -210,7 +210,7 @@ func TestAccount_Subscription_AddUpdateDelete(t *testing.T) {
func TestAccount_ChangePassword(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
rr := request(t, s, "POST", "/v1/account/password", `{"password": "phil", "new_password": "new password"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
@ -237,7 +237,7 @@ func TestAccount_ChangePassword_NoAccount(t *testing.T) {
func TestAccount_ExtendToken(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
rr := request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
@ -260,7 +260,7 @@ func TestAccount_ExtendToken(t *testing.T) {
func TestAccount_ExtendToken_NoTokenProvided(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
rr := request(t, s, "PATCH", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"), // Not Bearer!
@ -271,7 +271,7 @@ func TestAccount_ExtendToken_NoTokenProvided(t *testing.T) {
func TestAccount_DeleteToken(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
rr := request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
@ -324,10 +324,15 @@ func TestAccount_Delete_Success(t *testing.T) {
})
require.Equal(t, 200, rr.Code)
// Account was marked deleted
rr = request(t, s, "GET", "/v1/account", "", map[string]string{
"Authorization": util.BasicAuth("phil", "mypass"),
})
require.Equal(t, 401, rr.Code)
// Cannot re-create account, since still exists
rr = request(t, s, "POST", "/v1/account", `{"username":"phil", "password":"mypass"}`, nil)
require.Equal(t, 409, rr.Code)
}
func TestAccount_Delete_Not_Allowed(t *testing.T) {
@ -360,7 +365,7 @@ func TestAccount_Reservation_AddAdminSuccess(t *testing.T) {
conf := newTestConfigWithAuthFile(t)
conf.EnableSignup = true
s := newTestServer(t, conf)
require.Nil(t, s.userManager.AddUser("phil", "adminpass", user.RoleAdmin, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "adminpass", user.RoleAdmin))
rr := request(t, s, "POST", "/v1/account/reservation", `{"topic":"mytopic","everyone":"deny-all"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "adminpass"),

View File

@ -18,15 +18,10 @@ import (
"io"
"net/http"
"net/netip"
"strings"
"time"
)
var (
errNotAPaidTier = errors.New("tier does not have billing price identifier")
errMultipleBillingSubscriptions = errors.New("cannot have multiple billing subscriptions")
errNoBillingSubscription = errors.New("user does not have an active billing subscription")
)
// Payments in ntfy are done via Stripe.
//
// Pretty much all payments related things are in this file. The following processes
@ -49,6 +44,16 @@ var (
// This is used to keep the local user database fields up to date. Stripe is the source of truth.
// What Stripe says is mirrored and not questioned.
var (
errNotAPaidTier = errors.New("tier does not have billing price identifier")
errMultipleBillingSubscriptions = errors.New("cannot have multiple billing subscriptions")
errNoBillingSubscription = errors.New("user does not have an active billing subscription")
)
var (
retryUserDelays = []time.Duration{3 * time.Second, 5 * time.Second, 7 * time.Second}
)
// handleBillingTiersGet returns all available paid tiers, and the free tier. This is to populate the upgrade dialog
// in the UI. Note that this endpoint does NOT have a user context (no v.user!).
func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
@ -114,7 +119,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
} else if tier.StripePriceID == "" {
return errNotAPaidTier
}
log.Info("Stripe: No existing subscription, creating checkout flow")
log.Info("%s Creating Stripe checkout flow", logHTTPPrefix(v, r))
var stripeCustomerID *string
if v.user.Billing.StripeCustomerID != "" {
stripeCustomerID = &v.user.Billing.StripeCustomerID
@ -138,9 +143,6 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
Quantity: stripe.Int64(1),
},
},
/*AutomaticTax: &stripe.CheckoutSessionAutomaticTaxParams{
Enabled: stripe.Bool(true),
},*/
}
sess, err := s.stripe.NewCheckoutSession(params)
if err != nil {
@ -155,7 +157,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
// handleAccountBillingSubscriptionCreateSuccess is called after the Stripe checkout session has succeeded. We use
// the session ID in the URL to retrieve the Stripe subscription and update the local database. This is the first
// and only time we can map the local username with the Stripe customer ID.
func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, _ *visitor) error {
func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, v *visitor) error {
// We don't have a v.user in this endpoint, only a userManager!
matches := apiAccountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path)
if len(matches) != 2 {
@ -182,7 +184,8 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
if err != nil {
return err
}
if err := s.updateSubscriptionAndTier(u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil {
v.SetUser(u)
if err := s.updateSubscriptionAndTier(logHTTPPrefix(v, r), u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil {
return err
}
http.Redirect(w, r, s.config.BaseURL+accountPath, http.StatusSeeOther)
@ -203,7 +206,7 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
if err != nil {
return err
}
log.Info("Stripe: Changing tier and subscription to %s", tier.Code)
log.Info("%s Changing billing tier to %s (price %s) for subscription %s", logHTTPPrefix(v, r), tier.Code, tier.StripePriceID, v.user.Billing.StripeSubscriptionID)
sub, err := s.stripe.GetSubscription(v.user.Billing.StripeSubscriptionID)
if err != nil {
return err
@ -228,6 +231,7 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
// handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user,
// and cancelling the Stripe subscription entirely
func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
log.Info("%s Deleting billing subscription %s", logHTTPPrefix(v, r), v.user.Billing.StripeSubscriptionID)
if v.user.Billing.StripeSubscriptionID != "" {
params := &stripe.SubscriptionParams{
CancelAtPeriodEnd: stripe.Bool(true),
@ -246,6 +250,7 @@ func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter,
if v.user.Billing.StripeCustomerID == "" {
return errHTTPBadRequestNotAPaidUser
}
log.Info("%s Creating billing portal session", logHTTPPrefix(v, r))
params := &stripe.BillingPortalSessionParams{
Customer: stripe.String(v.user.Billing.StripeCustomerID),
ReturnURL: stripe.String(s.config.BaseURL),
@ -280,28 +285,30 @@ func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Requ
} else if event.Data == nil || event.Data.Raw == nil {
return errHTTPBadRequestBillingRequestInvalid
}
log.Info("Stripe: webhook event %s received", event.Type)
switch event.Type {
case "customer.subscription.updated":
return s.handleAccountBillingWebhookSubscriptionUpdated(event.Data.Raw)
case "customer.subscription.deleted":
return s.handleAccountBillingWebhookSubscriptionDeleted(event.Data.Raw)
default:
log.Warn("STRIPE Unhandled webhook event %s received", event.Type)
return nil
}
}
func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMessage) error {
r, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event)))
ev, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event)))
if err != nil {
return err
} else if r.ID == "" || r.Customer == "" || r.Status == "" || r.CurrentPeriodEnd == 0 || r.Items == nil || len(r.Items.Data) != 1 || r.Items.Data[0].Price == nil || r.Items.Data[0].Price.ID == "" {
} else if ev.ID == "" || ev.Customer == "" || ev.Status == "" || ev.CurrentPeriodEnd == 0 || ev.Items == nil || len(ev.Items.Data) != 1 || ev.Items.Data[0].Price == nil || ev.Items.Data[0].Price.ID == "" {
return errHTTPBadRequestBillingRequestInvalid
}
subscriptionID, priceID := r.ID, r.Items.Data[0].Price.ID
log.Info("Stripe: customer %s: Updating subscription to status %s, with price %s", r.Customer, r.Status, priceID)
u, err := s.userManager.UserByStripeCustomer(r.Customer)
subscriptionID, priceID := ev.ID, ev.Items.Data[0].Price.ID
log.Info("%s Updating subscription to status %s, with price %s", logStripePrefix(ev.Customer, ev.ID), ev.Status, priceID)
userFn := func() (*user.User, error) {
return s.userManager.UserByStripeCustomer(ev.Customer)
}
u, err := util.Retry[user.User](userFn, retryUserDelays...)
if err != nil {
return err
}
@ -309,7 +316,7 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe
if err != nil {
return err
}
if err := s.updateSubscriptionAndTier(u, tier, r.Customer, subscriptionID, r.Status, r.CurrentPeriodEnd, r.CancelAt); err != nil {
if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, tier, ev.Customer, subscriptionID, ev.Status, ev.CurrentPeriodEnd, ev.CancelAt); err != nil {
return err
}
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
@ -317,47 +324,56 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe
}
func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMessage) error {
r, err := util.UnmarshalJSON[apiStripeSubscriptionDeletedEvent](io.NopCloser(bytes.NewReader(event)))
ev, err := util.UnmarshalJSON[apiStripeSubscriptionDeletedEvent](io.NopCloser(bytes.NewReader(event)))
if err != nil {
return err
} else if r.Customer == "" {
} else if ev.Customer == "" {
return errHTTPBadRequestBillingRequestInvalid
}
log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", r.Customer)
u, err := s.userManager.UserByStripeCustomer(r.Customer)
log.Info("%s Subscription deleted, downgrading to unpaid tier", logStripePrefix(ev.Customer, ev.ID))
u, err := s.userManager.UserByStripeCustomer(ev.Customer)
if err != nil {
return err
}
if err := s.updateSubscriptionAndTier(u, nil, r.Customer, "", "", 0, 0); err != nil {
if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, nil, ev.Customer, "", "", 0, 0); err != nil {
return err
}
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
return nil
}
func (s *Server) updateSubscriptionAndTier(u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error {
// Remove excess reservations (if too many for tier), and mark associated messages deleted
// maybeRemoveExcessReservations deletes topic reservations for the given user (if too many for tier),
// and marks associated messages for the topics as deleted. This also eventually deletes attachments.
// The process relies on the manager to perform the actual deletions (see runManager).
func (s *Server) maybeRemoveExcessReservations(logPrefix string, u *user.User, reservationsLimit int64) error {
reservations, err := s.userManager.Reservations(u.Name)
if err != nil {
return err
} else if int64(len(reservations)) <= reservationsLimit {
return nil
}
topics := make([]string, 0)
for i := int64(len(reservations)) - 1; i >= reservationsLimit; i-- {
topics = append(topics, reservations[i].Topic)
}
log.Info("%s Removing excess reservations for topics %s", logPrefix, strings.Join(topics, ", "))
if err := s.userManager.RemoveReservations(u.Name, topics...); err != nil {
return err
}
if err := s.messageCache.ExpireMessages(topics...); err != nil {
return err
}
return nil
}
func (s *Server) updateSubscriptionAndTier(logPrefix string, u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error {
reservationsLimit := visitorDefaultReservationsLimit
if tier != nil {
reservationsLimit = tier.ReservationsLimit
}
if int64(len(reservations)) > reservationsLimit {
topics := make([]string, 0)
for i := int64(len(reservations)) - 1; i >= reservationsLimit; i-- {
topics = append(topics, reservations[i].Topic)
}
if err := s.userManager.RemoveReservations(u.Name, topics...); err != nil {
return err
}
if err := s.messageCache.ExpireMessages(topics...); err != nil {
return err
}
if err := s.maybeRemoveExcessReservations(logPrefix, u, reservationsLimit); err != nil {
return err
}
// Change or remove tier
if tier == nil {
if err := s.userManager.ResetTier(u.Name); err != nil {
return err

View File

@ -34,7 +34,7 @@ func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
Code: "pro",
StripePriceID: "price_123",
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
// Create subscription
response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
@ -69,7 +69,7 @@ func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
Code: "pro",
StripePriceID: "price_123",
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
u, err := s.userManager.User("phil")
require.Nil(t, err)
@ -110,7 +110,7 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
Code: "pro",
StripePriceID: "price_123",
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
u, err := s.userManager.User("phil")
require.Nil(t, err)
@ -174,7 +174,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
AttachmentFileSizeLimit: 1000000,
AttachmentTotalSizeLimit: 1000000,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll))
require.Nil(t, s.userManager.AddReservation("phil", "ztopic", user.PermissionDenyAll))

View File

@ -625,7 +625,7 @@ func TestServer_Auth_Success_Admin(t *testing.T) {
c := newTestConfigWithAuthFile(t)
s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
@ -639,7 +639,7 @@ func TestServer_Auth_Success_User(t *testing.T) {
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, s.userManager.AllowAccess("ben", "mytopic", user.PermissionReadWrite))
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
@ -653,7 +653,7 @@ func TestServer_Auth_Success_User_MultipleTopics(t *testing.T) {
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, s.userManager.AllowAccess("ben", "mytopic", user.PermissionReadWrite))
require.Nil(t, s.userManager.AllowAccess("ben", "anothertopic", user.PermissionReadWrite))
@ -674,7 +674,7 @@ func TestServer_Auth_Fail_InvalidPass(t *testing.T) {
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
"Authorization": util.BasicAuth("phil", "INVALID"),
@ -687,7 +687,7 @@ func TestServer_Auth_Fail_Unauthorized(t *testing.T) {
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, s.userManager.AllowAccess("ben", "sometopic", user.PermissionReadWrite)) // Not mytopic!
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
@ -701,7 +701,7 @@ func TestServer_Auth_Fail_CannotPublish(t *testing.T) {
c.AuthDefault = user.PermissionReadWrite // Open by default
s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
require.Nil(t, s.userManager.AllowAccess(user.Everyone, "private", user.PermissionDenyAll))
require.Nil(t, s.userManager.AllowAccess(user.Everyone, "announcements", user.PermissionRead))
@ -731,7 +731,7 @@ func TestServer_Auth_ViaQuery(t *testing.T) {
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("ben", "some pass", user.RoleAdmin, "unit-test"))
require.Nil(t, s.userManager.AddUser("ben", "some pass", user.RoleAdmin))
u := fmt.Sprintf("/mytopic/json?poll=1&auth=%s", base64.RawURLEncoding.EncodeToString([]byte(util.BasicAuth("ben", "some pass"))))
response := request(t, s, "GET", u, "", nil)
@ -749,7 +749,7 @@ func TestServer_StatsResetter(t *testing.T) {
s := newTestServer(t, c)
go s.runStatsResetter()
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.AllowAccess("phil", "mytopic", user.PermissionReadWrite))
for i := 0; i < 5; i++ {
@ -1137,7 +1137,7 @@ func TestServer_PublishWithTierBasedMessageLimitAndExpiry(t *testing.T) {
MessagesLimit: 5,
MessagesExpiryDuration: -5 * time.Second, // Second, what a hack!
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("phil", "test"))
// Publish to reach message limit
@ -1369,7 +1369,7 @@ func TestServer_PublishAttachmentWithTierBasedExpiry(t *testing.T) {
AttachmentTotalSizeLimit: 200_000,
AttachmentExpiryDuration: sevenDays, // 7 days
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("phil", "test"))
// Publish and make sure we can retrieve it
@ -1413,7 +1413,7 @@ func TestServer_PublishAttachmentWithTierBasedLimits(t *testing.T) {
AttachmentTotalSizeLimit: 200_000,
AttachmentExpiryDuration: 30 * time.Second,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("phil", "test"))
// Publish small file as anonymous

View File

@ -354,5 +354,6 @@ type apiStripeSubscriptionUpdatedEvent struct {
}
type apiStripeSubscriptionDeletedEvent struct {
ID string `json:"id"`
Customer string `json:"customer"`
}

View File

@ -49,7 +49,7 @@ func readQueryParam(r *http.Request, names ...string) string {
}
func logMessagePrefix(v *visitor, m *message) string {
return fmt.Sprintf("%s/%s/%s", v.ip, m.Topic, m.ID)
return fmt.Sprintf("%s/%s/%s", v.String(), m.Topic, m.ID)
}
func logHTTPPrefix(v *visitor, r *http.Request) string {
@ -57,7 +57,14 @@ func logHTTPPrefix(v *visitor, r *http.Request) string {
if requestURI == "" {
requestURI = r.URL.Path
}
return fmt.Sprintf("%s HTTP %s %s", v.ip, r.Method, requestURI)
return fmt.Sprintf("%s HTTP %s %s", v.String(), r.Method, requestURI)
}
func logStripePrefix(customerID, subscriptionID string) string {
if subscriptionID != "" {
return fmt.Sprintf("%s/%s STRIPE", customerID, subscriptionID)
}
return fmt.Sprintf("%s STRIPE", customerID)
}
func logSMTPPrefix(state *smtp.ConnectionState) string {

View File

@ -2,6 +2,7 @@ package server
import (
"errors"
"fmt"
"heckel.io/ntfy/user"
"net/netip"
"sync"
@ -119,6 +120,17 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana
}
}
func (v *visitor) String() string {
v.mu.Lock()
defer v.mu.Unlock()
if v.user != nil && v.user.Billing.StripeCustomerID != "" {
return fmt.Sprintf("%s/%s/%s", v.ip.String(), v.user.ID, v.user.Billing.StripeCustomerID)
} else if v.user != nil {
return fmt.Sprintf("%s/%s", v.ip.String(), v.user.ID)
}
return v.ip.String()
}
func (v *visitor) RequestAllowed() error {
if !v.requestLimiter.Allow() {
return errVisitorLimitReached
@ -216,6 +228,12 @@ func (v *visitor) ResetStats() {
}
}
func (v *visitor) SetUser(u *user.User) {
v.mu.Lock()
defer v.mu.Unlock()
v.user = u
}
func (v *visitor) Limits() *visitorLimits {
v.mu.Lock()
defer v.mu.Unlock()

View File

@ -25,6 +25,7 @@ const (
userPasswordBcryptCost = 10
userAuthIntentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match userPasswordBcryptCost
userStatsQueueWriterInterval = 33 * time.Second
userHardDeleteAfterDuration = 7 * 24 * time.Hour
tokenPrefix = "tk_"
tokenLength = 32
tokenMaxCount = 10 // Only keep this many tokens in the table per user
@ -57,7 +58,7 @@ const (
CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
CREATE TABLE IF NOT EXISTS user (
id TEXT PRIMARY KEY,
tier_id INT,
tier_id TEXT,
user TEXT NOT NULL,
pass TEXT NOT NULL,
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
@ -70,8 +71,8 @@ const (
stripe_subscription_status TEXT,
stripe_subscription_paid_until INT,
stripe_subscription_cancel_at INT,
created_by TEXT NOT NULL,
created_at INT NOT NULL,
created INT NOT NULL,
deleted INT,
FOREIGN KEY (tier_id) REFERENCES tier (id)
);
CREATE UNIQUE INDEX idx_user ON user (user);
@ -98,8 +99,8 @@ const (
id INT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO user (id, user, pass, role, sync_topic, created_by, created_at)
VALUES ('u_everyone', '*', '', 'anonymous', '', 'system', UNIXEPOCH())
INSERT INTO user (id, user, pass, role, sync_topic, created)
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', UNIXEPOCH())
ON CONFLICT (id) DO NOTHING;
`
createTablesQueries = `BEGIN; ` + createTablesQueriesNoTx + ` COMMIT;`
@ -108,26 +109,26 @@ const (
`
selectUserByIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE u.id = ?
`
selectUserByNameQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE user = ?
WHERE user = ?
`
selectUserByTokenQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
FROM user u
JOIN user_token t on u.id = t.user_id
LEFT JOIN tier t on t.id = u.tier_id
WHERE t.token = ? AND t.expires >= ?
`
selectUserByStripeCustomerIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE u.stripe_customer_id = ?
@ -141,8 +142,8 @@ const (
`
insertUserQuery = `
INSERT INTO user (id, user, pass, role, sync_topic, created_by, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?)
INSERT INTO user (id, user, pass, role, sync_topic, created)
VALUES (?, ?, ?, ?, ?, ?)
`
selectUsernamesQuery = `
SELECT user
@ -159,6 +160,8 @@ const (
updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE user = ?`
updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ? WHERE id = ?`
updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0`
updateUserDeletedQuery = `UPDATE user SET deleted = ? WHERE id = ?`
deleteUsersMarkedQuery = `DELETE FROM user WHERE deleted < ?`
deleteUserQuery = `DELETE FROM user WHERE user = ?`
upsertUserAccessQuery = `
@ -214,7 +217,8 @@ const (
selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
insertTokenQuery = `INSERT INTO user_token (user_id, token, expires) VALUES (?, ?, ?)`
updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?`
deleteTokenQuery = `DELETE FROM user_token WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?`
deleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
deleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?`
deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires < ?`
deleteExcessTokensQuery = `
DELETE FROM user_token
@ -268,8 +272,8 @@ const (
`
migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old`
migrate1To2InsertUserNoTx = `
INSERT INTO user (id, user, pass, role, sync_topic, created_by, created_at)
SELECT ?, user, pass, role, ?, 'admin', UNIXEPOCH() FROM user_old WHERE user = ?
INSERT INTO user (id, user, pass, role, sync_topic, created)
SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ?
`
migrate1To2InsertFromOldTablesAndDropNoTx = `
INSERT INTO user_access (user_id, topic, read, write)
@ -320,9 +324,9 @@ func newManager(filename, startupQueries string, defaultAccess Permission, stats
return manager, nil
}
// Authenticate checks username and password and returns a User if correct. The method
// returns in constant-ish time, regardless of whether the user exists or the password is
// correct or incorrect.
// Authenticate checks username and password and returns a User if correct, and the user has not been
// marked as deleted. The method returns in constant-ish time, regardless of whether the user exists or
// the password is correct or incorrect.
func (a *Manager) Authenticate(username, password string) (*User, error) {
if username == Everyone {
return nil, ErrUnauthenticated
@ -332,9 +336,12 @@ func (a *Manager) Authenticate(username, password string) (*User, error) {
log.Trace("authentication of user %s failed (1): %s", username, err.Error())
bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
return nil, ErrUnauthenticated
}
if err := bcrypt.CompareHashAndPassword([]byte(user.Hash), []byte(password)); err != nil {
log.Trace("authentication of user %s failed (2): %s", username, err.Error())
} else if user.Deleted {
log.Trace("authentication of user %s failed (2): user marked deleted", username)
bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
return nil, ErrUnauthenticated
} else if err := bcrypt.CompareHashAndPassword([]byte(user.Hash), []byte(password)); err != nil {
log.Trace("authentication of user %s failed (3): %s", username, err.Error())
return nil, ErrUnauthenticated
}
return user, nil
@ -415,7 +422,7 @@ func (a *Manager) RemoveToken(user *User) error {
if user.Token == "" {
return ErrUnauthorized
}
if _, err := a.db.Exec(deleteTokenQuery, user.Name, user.Token); err != nil {
if _, err := a.db.Exec(deleteTokenQuery, user.ID, user.Token); err != nil {
return err
}
return nil
@ -429,6 +436,14 @@ func (a *Manager) RemoveExpiredTokens() error {
return nil
}
// RemoveDeletedUsers deletes all users that have been marked deleted for
func (a *Manager) RemoveDeletedUsers() error {
if _, err := a.db.Exec(deleteUsersMarkedQuery, time.Now().Unix()); err != nil {
return err
}
return nil
}
// ChangeSettings persists the user settings
func (a *Manager) ChangeSettings(user *User) error {
prefs, err := json.Marshal(user.Prefs)
@ -533,7 +548,7 @@ func (a *Manager) resolvePerms(base, perm Permission) error {
}
// AddUser adds a user with the given username, password and role
func (a *Manager) AddUser(username, password string, role Role, createdBy string) error {
func (a *Manager) AddUser(username, password string, role Role) error {
if !AllowedUsername(username) || !AllowedRole(role) {
return ErrInvalidArgument
}
@ -543,7 +558,7 @@ func (a *Manager) AddUser(username, password string, role Role, createdBy string
}
userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
syncTopic, now := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength), time.Now().Unix()
if _, err = a.db.Exec(insertUserQuery, userID, username, hash, role, syncTopic, createdBy, now); err != nil {
if _, err = a.db.Exec(insertUserQuery, userID, username, hash, role, syncTopic, now); err != nil {
return err
}
return nil
@ -562,6 +577,29 @@ func (a *Manager) RemoveUser(username string) error {
return nil
}
// MarkUserRemoved sets the deleted flag on the user, and deletes all access tokens. This prevents
// successful auth via Authenticate. A background process will delete the user at a later date.
func (a *Manager) MarkUserRemoved(user *User) error {
if !AllowedUsername(user.Name) {
return ErrInvalidArgument
}
tx, err := a.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := a.db.Exec(deleteUserAccessQuery, user.Name, user.Name); err != nil {
return err
}
if _, err := tx.Exec(deleteAllTokenQuery, user.ID); err != nil {
return err
}
if _, err := tx.Exec(updateUserDeletedQuery, time.Now().Add(userHardDeleteAfterDuration).Unix(), user.ID); err != nil {
return err
}
return tx.Commit()
}
// Users returns a list of users. It always also returns the Everyone user ("*").
func (a *Manager) Users() ([]*User, error) {
rows, err := a.db.Query(selectUsernamesQuery)
@ -632,11 +670,11 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
var id, username, hash, role, prefs, syncTopic string
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString
var messages, emails int64
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt sql.NullInt64
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
if !rows.Next() {
return nil, ErrUserNotFound
}
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
@ -659,6 +697,7 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), // May be zero
StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0), // May be zero
},
Deleted: deleted.Valid,
}
if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil {
return nil, err

View File

@ -13,8 +13,8 @@ const minBcryptTimingMillis = int64(50) // Ideally should be >100ms, but this sh
func TestManager_FullScenario_Default_DenyAll(t *testing.T) {
a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, "unit-test"))
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
require.Nil(t, a.AddUser("phil", "phil", RoleAdmin))
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite))
require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead))
require.Nil(t, a.AllowAccess("ben", "writeme", PermissionWrite))
@ -92,20 +92,20 @@ func TestManager_FullScenario_Default_DenyAll(t *testing.T) {
func TestManager_AddUser_Invalid(t *testing.T) {
a := newTestManager(t, PermissionDenyAll)
require.Equal(t, ErrInvalidArgument, a.AddUser(" invalid ", "pass", RoleAdmin, "unit-test"))
require.Equal(t, ErrInvalidArgument, a.AddUser("validuser", "pass", "invalid-role", "unit-test"))
require.Equal(t, ErrInvalidArgument, a.AddUser(" invalid ", "pass", RoleAdmin))
require.Equal(t, ErrInvalidArgument, a.AddUser("validuser", "pass", "invalid-role"))
}
func TestManager_AddUser_Timing(t *testing.T) {
a := newTestManager(t, PermissionDenyAll)
start := time.Now().UnixMilli()
require.Nil(t, a.AddUser("user", "pass", RoleAdmin, "unit-test"))
require.Nil(t, a.AddUser("user", "pass", RoleAdmin))
require.GreaterOrEqual(t, time.Now().UnixMilli()-start, minBcryptTimingMillis)
}
func TestManager_Authenticate_Timing(t *testing.T) {
a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.AddUser("user", "pass", RoleAdmin, "unit-test"))
require.Nil(t, a.AddUser("user", "pass", RoleAdmin))
// Timing a correct attempt
start := time.Now().UnixMilli()
@ -126,10 +126,60 @@ func TestManager_Authenticate_Timing(t *testing.T) {
require.GreaterOrEqual(t, time.Now().UnixMilli()-start, minBcryptTimingMillis)
}
func TestManager_MarkUserRemoved_RemoveDeletedUsers(t *testing.T) {
a := newTestManager(t, PermissionDenyAll)
// Create user, add reservations and token
require.Nil(t, a.AddUser("user", "pass", RoleAdmin))
require.Nil(t, a.AddReservation("user", "mytopic", PermissionRead))
u, err := a.User("user")
require.Nil(t, err)
require.False(t, u.Deleted)
token, err := a.CreateToken(u)
require.Nil(t, err)
u, err = a.Authenticate("user", "pass")
require.Nil(t, err)
_, err = a.AuthenticateToken(token.Value)
require.Nil(t, err)
reservations, err := a.Reservations("user")
require.Nil(t, err)
require.Equal(t, 1, len(reservations))
// Mark deleted: cannot auth anymore, and all reservations are gone
require.Nil(t, a.MarkUserRemoved(u))
_, err = a.Authenticate("user", "pass")
require.Equal(t, ErrUnauthenticated, err)
_, err = a.AuthenticateToken(token.Value)
require.Equal(t, ErrUnauthenticated, err)
reservations, err = a.Reservations("user")
require.Nil(t, err)
require.Equal(t, 0, len(reservations))
// Make sure user is still there
u, err = a.User("user")
require.Nil(t, err)
require.True(t, u.Deleted)
_, err = a.db.Exec("UPDATE user SET deleted = ? WHERE id = ?", time.Now().Add(-1*(userHardDeleteAfterDuration+time.Hour)).Unix(), u.ID)
require.Nil(t, err)
require.Nil(t, a.RemoveDeletedUsers())
_, err = a.User("user")
require.Equal(t, ErrUserNotFound, err)
}
func TestManager_UserManagement(t *testing.T) {
a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, "unit-test"))
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
require.Nil(t, a.AddUser("phil", "phil", RoleAdmin))
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite))
require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead))
require.Nil(t, a.AllowAccess("ben", "writeme", PermissionWrite))
@ -219,7 +269,7 @@ func TestManager_UserManagement(t *testing.T) {
func TestManager_ChangePassword(t *testing.T) {
a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, "unit-test"))
require.Nil(t, a.AddUser("phil", "phil", RoleAdmin))
_, err := a.Authenticate("phil", "phil")
require.Nil(t, err)
@ -233,7 +283,7 @@ func TestManager_ChangePassword(t *testing.T) {
func TestManager_ChangeRole(t *testing.T) {
a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite))
require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead))
@ -258,7 +308,7 @@ func TestManager_ChangeRole(t *testing.T) {
func TestManager_Reservations(t *testing.T) {
a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
require.Nil(t, a.AddReservation("ben", "ztopic", PermissionDenyAll))
require.Nil(t, a.AddReservation("ben", "readme", PermissionRead))
require.Nil(t, a.AllowAccess("ben", "something-else", PermissionRead))
@ -292,7 +342,7 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
AttachmentTotalSizeLimit: 524288000,
AttachmentExpiryDuration: 24 * time.Hour,
}))
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
require.Nil(t, a.ChangeTier("ben", "pro"))
require.Nil(t, a.AddReservation("ben", "mytopic", PermissionDenyAll))
@ -340,7 +390,7 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
func TestManager_Token_Valid(t *testing.T) {
a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
u, err := a.User("ben")
require.Nil(t, err)
@ -365,7 +415,7 @@ func TestManager_Token_Valid(t *testing.T) {
func TestManager_Token_Invalid(t *testing.T) {
a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
u, err := a.AuthenticateToken(strings.Repeat("x", 32)) // 32 == token length
require.Nil(t, u)
@ -378,7 +428,7 @@ func TestManager_Token_Invalid(t *testing.T) {
func TestManager_Token_Expire(t *testing.T) {
a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
u, err := a.User("ben")
require.Nil(t, err)
@ -426,7 +476,7 @@ func TestManager_Token_Expire(t *testing.T) {
func TestManager_Token_Extend(t *testing.T) {
a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
// Try to extend token for user without token
u, err := a.User("ben")
@ -453,7 +503,7 @@ func TestManager_Token_Extend(t *testing.T) {
func TestManager_Token_MaxCount_AutoDelete(t *testing.T) {
a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
// Try to extend token for user without token
u, err := a.User("ben")
@ -497,7 +547,7 @@ func TestManager_Token_MaxCount_AutoDelete(t *testing.T) {
func TestManager_EnqueueStats(t *testing.T) {
a, err := newManager(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, 1500*time.Millisecond)
require.Nil(t, err)
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
// Baseline: No messages or emails
u, err := a.User("ben")
@ -527,7 +577,7 @@ func TestManager_EnqueueStats(t *testing.T) {
func TestManager_ChangeSettings(t *testing.T) {
a, err := newManager(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, 1500*time.Millisecond)
require.Nil(t, err)
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
// No settings
u, err := a.User("ben")

View File

@ -20,8 +20,7 @@ type User struct {
Stats *Stats
Billing *Billing
SyncTopic string
Created time.Time
LastSeen time.Time
Deleted bool
}
// Auther is an interface for authentication and authorization
@ -186,7 +185,8 @@ const (
// Everyone is a special username representing anonymous users
const (
Everyone = "*"
Everyone = "*"
everyoneID = "u_everyone"
)
var (

View File

@ -324,3 +324,15 @@ func UnmarshalJSONWithLimit[T any](r io.ReadCloser, limit int) (*T, error) {
}
return &obj, nil
}
// Retry executes function f until if succeeds, and then returns t. If f fails, it sleeps
// and tries again. The sleep durations are passed as the after params.
func Retry[T any](f func() (*T, error), after ...time.Duration) (t *T, err error) {
for _, delay := range after {
if t, err = f(); err == nil {
return t, nil
}
time.Sleep(delay)
}
return nil, err
}