From 9c082a83315e43962778513fe9607f3f5f0b67a9 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sat, 21 Jan 2023 23:15:22 -0500 Subject: [PATCH] Introduce text IDs for everything (esp user), to avoid security and accounting issues --- server/message_cache.go | 8 +- server/message_cache_test.go | 21 ++-- server/server.go | 9 +- server/server_account.go | 2 +- server/server_account_test.go | 2 +- server/server_payments.go | 4 +- server/server_payments_test.go | 4 +- server/server_test.go | 4 +- server/visitor.go | 2 +- user/manager.go | 191 ++++++++++++++++++++------------- user/manager_test.go | 10 +- user/types.go | 2 + util/util.go | 9 +- 13 files changed, 160 insertions(+), 108 deletions(-) diff --git a/server/message_cache.go b/server/message_cache.go index 8788cf99..b9723bde 100644 --- a/server/message_cache.go +++ b/server/message_cache.go @@ -98,8 +98,8 @@ const ( updateAttachmentDeleted = `UPDATE messages SET attachment_deleted = 1 WHERE mid = ?` selectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= ? AND attachment_deleted = 0` - selectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE sender = ? AND attachment_expires >= ?` - selectAttachmentsSizeByUserQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?` + selectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = '' AND sender = ? AND attachment_expires >= ?` + selectAttachmentsSizeByUserIDQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?` ) // Schema management queries @@ -563,8 +563,8 @@ func (c *messageCache) AttachmentBytesUsedBySender(sender string) (int64, error) return c.readAttachmentBytesUsed(rows) } -func (c *messageCache) AttachmentBytesUsedByUser(user string) (int64, error) { - rows, err := c.db.Query(selectAttachmentsSizeByUserQuery, user, time.Now().Unix()) +func (c *messageCache) AttachmentBytesUsedByUser(userID string) (int64, error) { + rows, err := c.db.Query(selectAttachmentsSizeByUserIDQuery, userID, time.Now().Unix()) if err != nil { return 0, err } diff --git a/server/message_cache_test.go b/server/message_cache_test.go index 2b838f25..79b7fc54 100644 --- a/server/message_cache_test.go +++ b/server/message_cache_test.go @@ -12,10 +12,6 @@ import ( "github.com/stretchr/testify/require" ) -var ( - exampleIP1234 = netip.MustParseAddr("1.2.3.4") -) - func TestSqliteCache_Messages(t *testing.T) { testCacheMessages(t, newSqliteTestCache(t)) } @@ -294,10 +290,10 @@ func TestMemCache_Attachments(t *testing.T) { } func testCacheAttachments(t *testing.T, c *messageCache) { - expires1 := time.Now().Add(-4 * time.Hour).Unix() + expires1 := time.Now().Add(-4 * time.Hour).Unix() // Expired m := newDefaultMessage("mytopic", "flower for you") m.ID = "m1" - m.Sender = exampleIP1234 + m.Sender = netip.MustParseAddr("1.2.3.4") m.Attachment = &attachment{ Name: "flower.jpg", Type: "image/jpeg", @@ -310,7 +306,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) { expires2 := time.Now().Add(2 * time.Hour).Unix() // Future m = newDefaultMessage("mytopic", "sending you a car") m.ID = "m2" - m.Sender = exampleIP1234 + m.Sender = netip.MustParseAddr("1.2.3.4") m.Attachment = &attachment{ Name: "car.jpg", Type: "image/jpeg", @@ -323,7 +319,8 @@ func testCacheAttachments(t *testing.T, c *messageCache) { expires3 := time.Now().Add(1 * time.Hour).Unix() // Future m = newDefaultMessage("another-topic", "sending you another car") m.ID = "m3" - m.Sender = exampleIP1234 + m.User = "u_BAsbaAa" + m.Sender = netip.MustParseAddr("5.6.7.8") m.Attachment = &attachment{ Name: "another-car.jpg", Type: "image/jpeg", @@ -355,11 +352,15 @@ func testCacheAttachments(t *testing.T, c *messageCache) { size, err := c.AttachmentBytesUsedBySender("1.2.3.4") require.Nil(t, err) - require.Equal(t, int64(30000), size) + require.Equal(t, int64(10000), size) size, err = c.AttachmentBytesUsedBySender("5.6.7.8") require.Nil(t, err) - require.Equal(t, int64(0), size) + require.Equal(t, int64(0), size) // Accounted to the user, not the IP! + + size, err = c.AttachmentBytesUsedByUser("u_BAsbaAa") + require.Nil(t, err) + require.Equal(t, int64(20000), size) } func TestSqliteCache_Attachments_Expired(t *testing.T) { diff --git a/server/server.go b/server/server.go index 12cff2b1..a098fc17 100644 --- a/server/server.go +++ b/server/server.go @@ -38,12 +38,13 @@ import ( TODO -- -- Security: Account re-creation leads to terrible behavior. Use user ID instead of user name for (a) visitor map, (b) messages.user column, (c) Stripe checkout session - Reservation: Kill existing subscribers when topic is reserved (deadcade) - Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben) - Reservation (UI): Ask for confirmation when removing reservation (deadcade) - Logging: Add detailed logging with username/customerID for all Stripe events (phil) - Rate limiting: Sensitive endpoints (account/login/change-password/...) +- Stripe webhook: Do not respond wih error if user does not exist (after account deletion) +- Stripe: Add metadata to customer races: - v.user --> see publishSyncEventAsync() test @@ -581,7 +582,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes m = newPollRequestMessage(t.ID, m.PollID) } if v.user != nil { - m.User = v.user.Name + m.User = v.user.ID } m.Expires = time.Now().Add(v.Limits().MessagesExpiryDuration).Unix() if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil { @@ -859,6 +860,7 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message, if m.Time > attachmentExpiry { return errHTTPBadRequestAttachmentsExpiryBeforeDelivery } + fmt.Printf("v = %#v\nlimits = %#v\nstats = %#v\n", v, vinfo.Limits, vinfo.Stats) contentLengthStr := r.Header.Get("Content-Length") if contentLengthStr != "" { // Early "do-not-trust" check, hard limit see below contentLength, err := strconv.ParseInt(contentLengthStr, 10, 64) @@ -885,6 +887,7 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message, util.NewFixedLimiter(vinfo.Limits.AttachmentFileSizeLimit), util.NewFixedLimiter(vinfo.Stats.AttachmentTotalSizeRemaining), } + fmt.Printf("limiters = %#v\nv = %#v\n", limiters, v) m.Attachment.Size, err = s.fileCache.Write(m.ID, body, limiters...) if err == util.ErrLimitReached { return errHTTPEntityTooLargeAttachment @@ -1657,7 +1660,7 @@ func (s *Server) visitorFromIP(ip netip.Addr) *visitor { } func (s *Server) visitorFromUser(user *user.User, ip netip.Addr) *visitor { - return s.visitorFromID(fmt.Sprintf("user:%s", user.Name), ip, user) + return s.visitorFromID(fmt.Sprintf("user:%s", user.ID), ip, user) } func (s *Server) writeJSON(w http.ResponseWriter, v any) error { diff --git a/server/server_account.go b/server/server_account.go index 755dcf75..1b243836 100644 --- a/server/server_account.go +++ b/server/server_account.go @@ -337,7 +337,7 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ return errHTTPTooManyRequestsLimitReservations } } - if err := s.userManager.ReserveAccess(v.user.Name, req.Topic, everyone); err != nil { + if err := s.userManager.AddReservation(v.user.Name, req.Topic, everyone); err != nil { return err } return s.writeJSON(w, newSuccessResponse()) diff --git a/server/server_account_test.go b/server/server_account_test.go index f1615ebf..4e4e452e 100644 --- a/server/server_account_test.go +++ b/server/server_account_test.go @@ -212,7 +212,7 @@ func TestAccount_ChangePassword(t *testing.T) { s := newTestServer(t, newTestConfigWithAuthFile(t)) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) - rr := request(t, s, "POST", "/v1/account/password", `{"password": "new password"}`, map[string]string{ + rr := request(t, s, "POST", "/v1/account/password", `{"password": "phil", "new_password": "new password"}`, map[string]string{ "Authorization": util.BasicAuth("phil", "phil"), }) require.Equal(t, 200, rr.Code) diff --git a/server/server_payments.go b/server/server_payments.go index c7ece4ef..40b961a3 100644 --- a/server/server_payments.go +++ b/server/server_payments.go @@ -128,7 +128,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r successURL := s.config.BaseURL + apiAccountBillingSubscriptionCheckoutSuccessTemplate params := &stripe.CheckoutSessionParams{ Customer: stripeCustomerID, // A user may have previously deleted their subscription - ClientReferenceID: &v.user.Name, + ClientReferenceID: &v.user.ID, SuccessURL: &successURL, Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)), AllowPromotionCodes: stripe.Bool(true), @@ -178,7 +178,7 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr if err != nil { return err } - u, err := s.userManager.User(sess.ClientReferenceID) + u, err := s.userManager.UserByID(sess.ClientReferenceID) if err != nil { return err } diff --git a/server/server_payments_test.go b/server/server_payments_test.go index 634109cb..4ee3d0e6 100644 --- a/server/server_payments_test.go +++ b/server/server_payments_test.go @@ -176,8 +176,8 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active( })) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test")) require.Nil(t, s.userManager.ChangeTier("phil", "pro")) - require.Nil(t, s.userManager.ReserveAccess("phil", "atopic", user.PermissionDenyAll)) - require.Nil(t, s.userManager.ReserveAccess("phil", "ztopic", user.PermissionDenyAll)) + require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll)) + require.Nil(t, s.userManager.AddReservation("phil", "ztopic", user.PermissionDenyAll)) // Add billing details u, err := s.userManager.User("phil") diff --git a/server/server_test.go b/server/server_test.go index 4d32f409..0986380c 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -830,7 +830,7 @@ func TestServer_PublishTooRequests_Defaults_ExemptHosts(t *testing.T) { func TestServer_PublishTooRequests_ShortReplenish(t *testing.T) { c := newTestConfig(t) c.VisitorRequestLimitBurst = 60 - c.VisitorRequestLimitReplenish = 500 * time.Millisecond + c.VisitorRequestLimitReplenish = time.Second s := newTestServer(t, c) for i := 0; i < 60; i++ { response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil) @@ -839,7 +839,7 @@ func TestServer_PublishTooRequests_ShortReplenish(t *testing.T) { response := request(t, s, "PUT", "/mytopic", "message", nil) require.Equal(t, 429, response.Code) - time.Sleep(520 * time.Millisecond) + time.Sleep(1020 * time.Millisecond) response = request(t, s, "PUT", "/mytopic", "message", nil) require.Equal(t, 200, response.Code) } diff --git a/server/visitor.go b/server/visitor.go index c752de8e..b23f66ef 100644 --- a/server/visitor.go +++ b/server/visitor.go @@ -241,7 +241,7 @@ func (v *visitor) Info() (*visitorInfo, error) { var attachmentsBytesUsed int64 var err error if v.user != nil { - attachmentsBytesUsed, err = v.messageCache.AttachmentBytesUsedByUser(v.user.Name) + attachmentsBytesUsed, err = v.messageCache.AttachmentBytesUsedByUser(v.user.ID) } else { attachmentsBytesUsed, err = v.messageCache.AttachmentBytesUsedBySender(v.ip.String()) } diff --git a/user/manager.go b/user/manager.go index 652017e7..5c3baf7c 100644 --- a/user/manager.go +++ b/user/manager.go @@ -16,13 +16,19 @@ import ( ) const ( - bcryptCost = 10 - intentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match bcryptCost - userStatsQueueWriterInterval = 33 * time.Second - tokenLength = 32 - tokenExpiryDuration = 72 * time.Hour // Extend tokens by this much - syncTopicLength = 16 - tokenMaxCount = 10 // Only keep this many tokens in the table per user + tierIDPrefix = "ti_" + tierIDLength = 8 + syncTopicPrefix = "st_" + syncTopicLength = 16 + userIDPrefix = "u_" + userIDLength = 12 + userPasswordBcryptCost = 10 + userAuthIntentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match userPasswordBcryptCost + userStatsQueueWriterInterval = 33 * time.Second + tokenPrefix = "tk_" + tokenLength = 32 + tokenMaxCount = 10 // Only keep this many tokens in the table per user + tokenExpiryDuration = 72 * time.Hour // Extend tokens by this much ) var ( @@ -35,7 +41,7 @@ var ( const ( createTablesQueriesNoTx = ` CREATE TABLE IF NOT EXISTS tier ( - id INTEGER PRIMARY KEY AUTOINCREMENT, + id TEXT PRIMARY KEY, code TEXT NOT NULL, name TEXT NOT NULL, messages_limit INT NOT NULL, @@ -50,7 +56,7 @@ const ( CREATE UNIQUE INDEX idx_tier_code ON tier (code); CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id); CREATE TABLE IF NOT EXISTS user ( - id INTEGER PRIMARY KEY AUTOINCREMENT, + id TEXT PRIMARY KEY, tier_id INT, user TEXT NOT NULL, pass TEXT NOT NULL, @@ -72,7 +78,7 @@ const ( CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id); CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id); CREATE TABLE IF NOT EXISTS user_access ( - user_id INT NOT NULL, + user_id TEXT NOT NULL, topic TEXT NOT NULL, read INT NOT NULL, write INT NOT NULL, @@ -82,7 +88,7 @@ const ( FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE ); CREATE TABLE IF NOT EXISTS user_token ( - user_id INT NOT NULL, + user_id TEXT NOT NULL, token TEXT NOT NULL, expires INT NOT NULL, PRIMARY KEY (user_id, token), @@ -93,7 +99,7 @@ const ( version INT NOT NULL ); INSERT INTO user (id, user, pass, role, sync_topic, created_by, created_at) - VALUES (1, '*', '', 'anonymous', '', 'system', UNIXEPOCH()) + VALUES ('u_everyone', '*', '', 'anonymous', '', 'system', UNIXEPOCH()) ON CONFLICT (id) DO NOTHING; ` createTablesQueries = `BEGIN; ` + createTablesQueriesNoTx + ` COMMIT;` @@ -101,21 +107,27 @@ const ( PRAGMA foreign_keys = ON; ` + selectUserByIDQuery = ` + SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id + FROM user u + LEFT JOIN tier t on t.id = u.tier_id + WHERE u.id = ? + ` selectUserByNameQuery = ` - SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id + SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id FROM user u LEFT JOIN tier t on t.id = u.tier_id WHERE user = ? ` selectUserByTokenQuery = ` - SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id + SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id FROM user u JOIN user_token t on u.id = t.user_id LEFT JOIN tier t on t.id = u.tier_id WHERE t.token = ? AND t.expires >= ? ` selectUserByStripeCustomerIDQuery = ` - SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id + SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id FROM user u LEFT JOIN tier t on t.id = u.tier_id WHERE u.stripe_customer_id = ? @@ -129,8 +141,8 @@ const ( ` insertUserQuery = ` - INSERT INTO user (user, pass, role, sync_topic, created_by, created_at) - VALUES (?, ?, ?, ?, ?, ?) + INSERT INTO user (id, user, pass, role, sync_topic, created_by, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?) ` selectUsernamesQuery = ` SELECT user @@ -145,7 +157,7 @@ const ( updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?` updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?` updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE user = ?` - updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ? WHERE user = ?` + updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ? WHERE id = ?` updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0` deleteUserQuery = `DELETE FROM user WHERE user = ?` @@ -199,8 +211,8 @@ const ( AND topic = ? ` - selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE (SELECT id FROM user WHERE user = ?)` - insertTokenQuery = `INSERT INTO user_token (user_id, token, expires) VALUES ((SELECT id FROM user WHERE user = ?), ?, ?)` + selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?` + insertTokenQuery = `INSERT INTO user_token (user_id, token, expires) VALUES (?, ?, ?)` updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?` deleteTokenQuery = `DELETE FROM user_token WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?` deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires < ?` @@ -209,27 +221,27 @@ const ( WHERE (user_id, token) NOT IN ( SELECT user_id, token FROM user_token - WHERE user_id = (SELECT id FROM user WHERE user = ?) + WHERE user_id = ? ORDER BY expires DESC LIMIT ? ) ` insertTierQuery = ` - INSERT INTO tier (code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ` selectTiersQuery = ` - SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id + SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id FROM tier ` selectTierByCodeQuery = ` - SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id + SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id FROM tier WHERE code = ? ` selectTierByPriceIDQuery = ` - SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id + SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id FROM tier WHERE stripe_price_id = ? ` @@ -254,10 +266,12 @@ const ( migrate1To2RenameUserTableQueryNoTx = ` ALTER TABLE user RENAME TO user_old; ` + migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old` + migrate1To2InsertUserNoTx = ` + INSERT INTO user (id, user, pass, role, sync_topic, created_by, created_at) + SELECT ?, user, pass, role, ?, 'admin', UNIXEPOCH() FROM user_old WHERE user = ? + ` migrate1To2InsertFromOldTablesAndDropNoTx = ` - INSERT INTO user (user, pass, role, sync_topic, created_by, created_at) - SELECT user, pass, role, '', 'admin', UNIXEPOCH() FROM user_old; - INSERT INTO user_access (user_id, topic, read, write) SELECT u.id, a.topic, a.read, a.write FROM user u @@ -266,8 +280,7 @@ const ( DROP TABLE access; DROP TABLE user_old; ` - migrate1To2SelectAllUsersIDsNoTx = `SELECT id FROM user` - migrate1To2UpdateSyncTopicNoTx = `UPDATE user SET sync_topic = ? WHERE id = ?` + migrate1To2UpdateSyncTopicNoTx = `UPDATE user SET sync_topic = ? WHERE id = ?` ) // Manager is an implementation of Manager. It stores users and access control list @@ -317,7 +330,7 @@ func (a *Manager) Authenticate(username, password string) (*User, error) { user, err := a.User(username) if err != nil { log.Trace("authentication of user %s failed (1): %s", username, err.Error()) - bcrypt.CompareHashAndPassword([]byte(intentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks")) + bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks")) return nil, ErrUnauthenticated } if err := bcrypt.CompareHashAndPassword([]byte(user.Hash), []byte(password)); err != nil { @@ -345,16 +358,16 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) { // after a fixed duration unless ExtendToken is called. This function also prunes tokens for the // given user, if there are too many of them. func (a *Manager) CreateToken(user *User) (*Token, error) { - token, expires := util.RandomString(tokenLength), time.Now().Add(tokenExpiryDuration) + token, expires := util.RandomStringPrefix(tokenPrefix, tokenLength), time.Now().Add(tokenExpiryDuration) tx, err := a.db.Begin() if err != nil { return nil, err } defer tx.Rollback() - if _, err := tx.Exec(insertTokenQuery, user.Name, token, expires.Unix()); err != nil { + if _, err := tx.Exec(insertTokenQuery, user.ID, token, expires.Unix()); err != nil { return nil, err } - rows, err := tx.Query(selectTokenCountQuery, user.Name) + rows, err := tx.Query(selectTokenCountQuery, user.ID) if err != nil { return nil, err } @@ -369,7 +382,7 @@ func (a *Manager) CreateToken(user *User) (*Token, error) { if tokenCount >= tokenMaxCount { // This pruning logic is done in two queries for efficiency. The SELECT above is a lookup // on two indices, whereas the query below is a full table scan. - if _, err := tx.Exec(deleteExcessTokensQuery, user.Name, tokenMaxCount); err != nil { + if _, err := tx.Exec(deleteExcessTokensQuery, user.ID, tokenMaxCount); err != nil { return nil, err } } @@ -444,7 +457,7 @@ func (a *Manager) ResetStats() error { func (a *Manager) EnqueueStats(user *User) { a.mu.Lock() defer a.mu.Unlock() - a.statsQueue[user.Name] = user + a.statsQueue[user.ID] = user } func (a *Manager) userStatsQueueWriter(interval time.Duration) { @@ -472,9 +485,9 @@ func (a *Manager) writeUserStatsQueue() error { } defer tx.Rollback() log.Debug("User Manager: Writing user stats queue for %d user(s)", len(statsQueue)) - for username, u := range statsQueue { - log.Trace("User Manager: Updating stats for user %s: messages=%d, emails=%d", username, u.Stats.Messages, u.Stats.Emails) - if _, err := tx.Exec(updateUserStatsQuery, u.Stats.Messages, u.Stats.Emails, username); err != nil { + for userID, u := range statsQueue { + log.Trace("User Manager: Updating stats for user %s: messages=%d, emails=%d", userID, u.Stats.Messages, u.Stats.Emails) + if _, err := tx.Exec(updateUserStatsQuery, u.Stats.Messages, u.Stats.Emails, userID); err != nil { return err } } @@ -524,12 +537,13 @@ func (a *Manager) AddUser(username, password string, role Role, createdBy string if !AllowedUsername(username) || !AllowedRole(role) { return ErrInvalidArgument } - hash, err := bcrypt.GenerateFromPassword([]byte(password), bcryptCost) + hash, err := bcrypt.GenerateFromPassword([]byte(password), userPasswordBcryptCost) if err != nil { return err } - syncTopic, now := util.RandomString(syncTopicLength), time.Now().Unix() - if _, err = a.db.Exec(insertUserQuery, username, hash, role, syncTopic, createdBy, now); err != nil { + userID := util.RandomStringPrefix(userIDPrefix, userIDLength) + syncTopic, now := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength), time.Now().Unix() + if _, err = a.db.Exec(insertUserQuery, userID, username, hash, role, syncTopic, createdBy, now); err != nil { return err } return nil @@ -587,6 +601,15 @@ func (a *Manager) User(username string) (*User, error) { return a.readUser(rows) } +// UserByID returns the user with the given ID if it exists, or ErrUserNotFound otherwise +func (a *Manager) UserByID(id string) (*User, error) { + rows, err := a.db.Query(selectUserByIDQuery, id) + if err != nil { + return nil, err + } + return a.readUser(rows) +} + // UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise. func (a *Manager) UserByStripeCustomer(stripeCustomerID string) (*User, error) { rows, err := a.db.Query(selectUserByStripeCustomerIDQuery, stripeCustomerID) @@ -606,19 +629,20 @@ func (a *Manager) userByToken(token string) (*User, error) { func (a *Manager) readUser(rows *sql.Rows) (*User, error) { defer rows.Close() - var username, hash, role, prefs, syncTopic string + var id, username, hash, role, prefs, syncTopic string var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString var messages, emails int64 var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt sql.NullInt64 if !rows.Next() { return nil, ErrUserNotFound } - if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil { + if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil { return nil, err } else if err := rows.Err(); err != nil { return nil, err } user := &User{ + ID: id, Name: username, Hash: hash, Role: Role(role), @@ -744,7 +768,7 @@ func (a *Manager) ReservationsCount(username string) (int64, error) { // ChangePassword changes a user's password func (a *Manager) ChangePassword(username, password string) error { - hash, err := bcrypt.GenerateFromPassword([]byte(password), bcryptCost) + hash, err := bcrypt.GenerateFromPassword([]byte(password), userPasswordBcryptCost) if err != nil { return err } @@ -818,6 +842,7 @@ func (a *Manager) checkReservationsLimit(username string, reservationsLimit int6 // CheckAllowAccess tests if a user may create an access control entry for the given topic. // If there are any ACL entries that are not owned by the user, an error is returned. +// FIXME is this the same as HasReservation? func (a *Manager) CheckAllowAccess(username string, topic string) error { if (!AllowedUsername(username) && username != Everyone) || !AllowedTopic(topic) { return ErrInvalidArgument @@ -856,24 +881,6 @@ func (a *Manager) AllowAccess(username string, topicPattern string, permission P return nil } -func (a *Manager) ReserveAccess(username string, topic string, everyone Permission) error { - if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) { - return ErrInvalidArgument - } - tx, err := a.db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - if _, err := tx.Exec(upsertUserAccessQuery, username, topic, true, true, username, username); err != nil { - return err - } - if _, err := tx.Exec(upsertUserAccessQuery, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username, username); err != nil { - return err - } - return tx.Commit() -} - // ResetAccess removes an access control list entry for a specific username/topic, or (if topic is // empty) for an entire user. The parameter topicPattern may include wildcards (*). func (a *Manager) ResetAccess(username string, topicPattern string) error { @@ -893,6 +900,29 @@ func (a *Manager) ResetAccess(username string, topicPattern string) error { return err } +// AddReservation creates two access control entries for the given topic: one with full read/write access for the +// given user, and one for Everyone with the permission passed as everyone. The user also owns the entries, and +// can modify or delete them. +func (a *Manager) AddReservation(username string, topic string, everyone Permission) error { + if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) { + return ErrInvalidArgument + } + tx, err := a.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + if _, err := tx.Exec(upsertUserAccessQuery, username, topic, true, true, username, username); err != nil { + return err + } + if _, err := tx.Exec(upsertUserAccessQuery, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username, username); err != nil { + return err + } + return tx.Commit() +} + +// RemoveReservations deletes the access control entries associated with the given username/topic, as +// well as all entries with Everyone/topic. This is the counterpart for AddReservation. func (a *Manager) RemoveReservations(username string, topics ...string) error { if !AllowedUsername(username) || username == Everyone || len(topics) == 0 { return ErrInvalidArgument @@ -925,7 +955,8 @@ func (a *Manager) DefaultAccess() Permission { // CreateTier creates a new tier in the database func (a *Manager) CreateTier(tier *Tier) error { - if _, err := a.db.Exec(insertTierQuery, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.StripePriceID); err != nil { + tierID := util.RandomStringPrefix(tierIDPrefix, tierIDLength) + if _, err := a.db.Exec(insertTierQuery, tierID, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.StripePriceID); err != nil { return err } return nil @@ -980,19 +1011,20 @@ func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) { } func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) { - var code, name string + var id, code, name string var stripePriceID sql.NullString var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64 if !rows.Next() { return nil, ErrTierNotFound } - if err := rows.Scan(&code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil { + if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil { return nil, err } else if err := rows.Err(); err != nil { return nil, err } // When changed, note readUser() as well return &Tier{ + ID: id, Code: code, Name: name, Paid: stripePriceID.Valid, // If there is a price, it's a paid tier @@ -1069,36 +1101,41 @@ func migrateFrom1(db *sql.DB) error { return err } defer tx.Rollback() + // Rename user -> user_old, and create new tables if _, err := tx.Exec(migrate1To2RenameUserTableQueryNoTx); err != nil { return err } if _, err := tx.Exec(createTablesQueriesNoTx); err != nil { return err } - if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil { - return err - } - rows, err := tx.Query(migrate1To2SelectAllUsersIDsNoTx) + // Insert users from user_old into new user table, with ID and sync_topic + rows, err := tx.Query(migrate1To2SelectAllOldUsernamesNoTx) if err != nil { return err } defer rows.Close() - syncTopics := make(map[int]string) + usernames := make([]string, 0) for rows.Next() { - var userID int - if err := rows.Scan(&userID); err != nil { + var username string + if err := rows.Scan(&username); err != nil { return err } - syncTopics[userID] = util.RandomString(syncTopicLength) + usernames = append(usernames, username) } if err := rows.Close(); err != nil { return err } - for userID, syncTopic := range syncTopics { - if _, err := tx.Exec(migrate1To2UpdateSyncTopicNoTx, syncTopic, userID); err != nil { + for _, username := range usernames { + userID := util.RandomStringPrefix(userIDPrefix, userIDLength) + syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength) + if _, err := tx.Exec(migrate1To2InsertUserNoTx, userID, syncTopic, username); err != nil { return err } } + // Migrate old "access" table to "user_access" and drop "access" and "user_old" + if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil { + return err + } if _, err := tx.Exec(updateSchemaVersion, 2); err != nil { return err } diff --git a/user/manager_test.go b/user/manager_test.go index fdbe5e0e..35d8cac8 100644 --- a/user/manager_test.go +++ b/user/manager_test.go @@ -259,8 +259,8 @@ func TestManager_ChangeRole(t *testing.T) { func TestManager_Reservations(t *testing.T) { a := newTestManager(t, PermissionDenyAll) require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) - require.Nil(t, a.ReserveAccess("ben", "ztopic", PermissionDenyAll)) - require.Nil(t, a.ReserveAccess("ben", "readme", PermissionRead)) + require.Nil(t, a.AddReservation("ben", "ztopic", PermissionDenyAll)) + require.Nil(t, a.AddReservation("ben", "readme", PermissionRead)) require.Nil(t, a.AllowAccess("ben", "something-else", PermissionRead)) reservations, err := a.Reservations("ben") @@ -294,7 +294,7 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) { })) require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test")) require.Nil(t, a.ChangeTier("ben", "pro")) - require.Nil(t, a.ReserveAccess("ben", "mytopic", PermissionDenyAll)) + require.Nil(t, a.AddReservation("ben", "mytopic", PermissionDenyAll)) ben, err := a.User("ben") require.Nil(t, err) @@ -626,11 +626,14 @@ func TestSqliteCache_Migration_From1(t *testing.T) { everyoneGrants, err := a.Grants(Everyone) require.Nil(t, err) + require.True(t, strings.HasPrefix(phil.ID, "u_")) require.Equal(t, "phil", phil.Name) require.Equal(t, RoleAdmin, phil.Role) require.Equal(t, syncTopicLength, len(phil.SyncTopic)) require.Equal(t, 0, len(philGrants)) + require.True(t, strings.HasPrefix(ben.ID, "u_")) + require.NotEqual(t, phil.ID, ben.ID) require.Equal(t, "ben", ben.Name) require.Equal(t, RoleUser, ben.Role) require.Equal(t, syncTopicLength, len(ben.SyncTopic)) @@ -641,6 +644,7 @@ func TestSqliteCache_Migration_From1(t *testing.T) { require.Equal(t, "secret", benGrants[1].TopicPattern) require.Equal(t, PermissionRead, benGrants[1].Allow) + require.Equal(t, "u_everyone", everyone.ID) require.Equal(t, Everyone, everyone.Name) require.Equal(t, RoleAnonymous, everyone.Role) require.Equal(t, 1, len(everyoneGrants)) diff --git a/user/types.go b/user/types.go index 5e95ad56..22137539 100644 --- a/user/types.go +++ b/user/types.go @@ -10,6 +10,7 @@ import ( // User is a struct that represents a user type User struct { + ID string Name string Hash string // password hash (bcrypt) Token string // Only set if token was used to log in @@ -50,6 +51,7 @@ type Prefs struct { // Tier represents a user's account type, including its account limits type Tier struct { + ID string Code string Name string Paid bool diff --git a/util/util.go b/util/util.go index ac1ed7b1..15a922d5 100644 --- a/util/util.go +++ b/util/util.go @@ -107,13 +107,18 @@ func LastString(s []string, def string) string { // RandomString returns a random string with a given length func RandomString(length int) string { + return RandomStringPrefix("", length) +} + +// RandomStringPrefix returns a random string with a given length, with a prefix +func RandomStringPrefix(prefix string, length int) string { randomMutex.Lock() // Who would have thought that random.Intn() is not thread-safe?! defer randomMutex.Unlock() - b := make([]byte, length) + b := make([]byte, length-len(prefix)) for i := range b { b[i] = randomStringCharset[random.Intn(len(randomStringCharset))] } - return string(b) + return prefix + string(b) } // ValidRandomString returns true if the given string matches the format created by RandomString