Introduce text IDs for everything (esp user), to avoid security and accounting issues

This commit is contained in:
binwiederhier 2023-01-21 23:15:22 -05:00
parent 88abd8872d
commit 9c082a8331
13 changed files with 160 additions and 108 deletions

View File

@ -98,8 +98,8 @@ const (
updateAttachmentDeleted = `UPDATE messages SET attachment_deleted = 1 WHERE mid = ?`
selectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= ? AND attachment_deleted = 0`
selectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE sender = ? AND attachment_expires >= ?`
selectAttachmentsSizeByUserQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?`
selectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = '' AND sender = ? AND attachment_expires >= ?`
selectAttachmentsSizeByUserIDQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?`
)
// Schema management queries
@ -563,8 +563,8 @@ func (c *messageCache) AttachmentBytesUsedBySender(sender string) (int64, error)
return c.readAttachmentBytesUsed(rows)
}
func (c *messageCache) AttachmentBytesUsedByUser(user string) (int64, error) {
rows, err := c.db.Query(selectAttachmentsSizeByUserQuery, user, time.Now().Unix())
func (c *messageCache) AttachmentBytesUsedByUser(userID string) (int64, error) {
rows, err := c.db.Query(selectAttachmentsSizeByUserIDQuery, userID, time.Now().Unix())
if err != nil {
return 0, err
}

View File

@ -12,10 +12,6 @@ import (
"github.com/stretchr/testify/require"
)
var (
exampleIP1234 = netip.MustParseAddr("1.2.3.4")
)
func TestSqliteCache_Messages(t *testing.T) {
testCacheMessages(t, newSqliteTestCache(t))
}
@ -294,10 +290,10 @@ func TestMemCache_Attachments(t *testing.T) {
}
func testCacheAttachments(t *testing.T, c *messageCache) {
expires1 := time.Now().Add(-4 * time.Hour).Unix()
expires1 := time.Now().Add(-4 * time.Hour).Unix() // Expired
m := newDefaultMessage("mytopic", "flower for you")
m.ID = "m1"
m.Sender = exampleIP1234
m.Sender = netip.MustParseAddr("1.2.3.4")
m.Attachment = &attachment{
Name: "flower.jpg",
Type: "image/jpeg",
@ -310,7 +306,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
expires2 := time.Now().Add(2 * time.Hour).Unix() // Future
m = newDefaultMessage("mytopic", "sending you a car")
m.ID = "m2"
m.Sender = exampleIP1234
m.Sender = netip.MustParseAddr("1.2.3.4")
m.Attachment = &attachment{
Name: "car.jpg",
Type: "image/jpeg",
@ -323,7 +319,8 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
expires3 := time.Now().Add(1 * time.Hour).Unix() // Future
m = newDefaultMessage("another-topic", "sending you another car")
m.ID = "m3"
m.Sender = exampleIP1234
m.User = "u_BAsbaAa"
m.Sender = netip.MustParseAddr("5.6.7.8")
m.Attachment = &attachment{
Name: "another-car.jpg",
Type: "image/jpeg",
@ -355,11 +352,15 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
size, err := c.AttachmentBytesUsedBySender("1.2.3.4")
require.Nil(t, err)
require.Equal(t, int64(30000), size)
require.Equal(t, int64(10000), size)
size, err = c.AttachmentBytesUsedBySender("5.6.7.8")
require.Nil(t, err)
require.Equal(t, int64(0), size)
require.Equal(t, int64(0), size) // Accounted to the user, not the IP!
size, err = c.AttachmentBytesUsedByUser("u_BAsbaAa")
require.Nil(t, err)
require.Equal(t, int64(20000), size)
}
func TestSqliteCache_Attachments_Expired(t *testing.T) {

View File

@ -38,12 +38,13 @@ import (
TODO
--
- Security: Account re-creation leads to terrible behavior. Use user ID instead of user name for (a) visitor map, (b) messages.user column, (c) Stripe checkout session
- Reservation: Kill existing subscribers when topic is reserved (deadcade)
- Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
- Reservation (UI): Ask for confirmation when removing reservation (deadcade)
- Logging: Add detailed logging with username/customerID for all Stripe events (phil)
- Rate limiting: Sensitive endpoints (account/login/change-password/...)
- Stripe webhook: Do not respond wih error if user does not exist (after account deletion)
- Stripe: Add metadata to customer
races:
- v.user --> see publishSyncEventAsync() test
@ -581,7 +582,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
m = newPollRequestMessage(t.ID, m.PollID)
}
if v.user != nil {
m.User = v.user.Name
m.User = v.user.ID
}
m.Expires = time.Now().Add(v.Limits().MessagesExpiryDuration).Unix()
if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil {
@ -859,6 +860,7 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message,
if m.Time > attachmentExpiry {
return errHTTPBadRequestAttachmentsExpiryBeforeDelivery
}
fmt.Printf("v = %#v\nlimits = %#v\nstats = %#v\n", v, vinfo.Limits, vinfo.Stats)
contentLengthStr := r.Header.Get("Content-Length")
if contentLengthStr != "" { // Early "do-not-trust" check, hard limit see below
contentLength, err := strconv.ParseInt(contentLengthStr, 10, 64)
@ -885,6 +887,7 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message,
util.NewFixedLimiter(vinfo.Limits.AttachmentFileSizeLimit),
util.NewFixedLimiter(vinfo.Stats.AttachmentTotalSizeRemaining),
}
fmt.Printf("limiters = %#v\nv = %#v\n", limiters, v)
m.Attachment.Size, err = s.fileCache.Write(m.ID, body, limiters...)
if err == util.ErrLimitReached {
return errHTTPEntityTooLargeAttachment
@ -1657,7 +1660,7 @@ func (s *Server) visitorFromIP(ip netip.Addr) *visitor {
}
func (s *Server) visitorFromUser(user *user.User, ip netip.Addr) *visitor {
return s.visitorFromID(fmt.Sprintf("user:%s", user.Name), ip, user)
return s.visitorFromID(fmt.Sprintf("user:%s", user.ID), ip, user)
}
func (s *Server) writeJSON(w http.ResponseWriter, v any) error {

View File

@ -337,7 +337,7 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
return errHTTPTooManyRequestsLimitReservations
}
}
if err := s.userManager.ReserveAccess(v.user.Name, req.Topic, everyone); err != nil {
if err := s.userManager.AddReservation(v.user.Name, req.Topic, everyone); err != nil {
return err
}
return s.writeJSON(w, newSuccessResponse())

View File

@ -212,7 +212,7 @@ func TestAccount_ChangePassword(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
rr := request(t, s, "POST", "/v1/account/password", `{"password": "new password"}`, map[string]string{
rr := request(t, s, "POST", "/v1/account/password", `{"password": "phil", "new_password": "new password"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)

View File

@ -128,7 +128,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
successURL := s.config.BaseURL + apiAccountBillingSubscriptionCheckoutSuccessTemplate
params := &stripe.CheckoutSessionParams{
Customer: stripeCustomerID, // A user may have previously deleted their subscription
ClientReferenceID: &v.user.Name,
ClientReferenceID: &v.user.ID,
SuccessURL: &successURL,
Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
AllowPromotionCodes: stripe.Bool(true),
@ -178,7 +178,7 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
if err != nil {
return err
}
u, err := s.userManager.User(sess.ClientReferenceID)
u, err := s.userManager.UserByID(sess.ClientReferenceID)
if err != nil {
return err
}

View File

@ -176,8 +176,8 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
require.Nil(t, s.userManager.ReserveAccess("phil", "atopic", user.PermissionDenyAll))
require.Nil(t, s.userManager.ReserveAccess("phil", "ztopic", user.PermissionDenyAll))
require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll))
require.Nil(t, s.userManager.AddReservation("phil", "ztopic", user.PermissionDenyAll))
// Add billing details
u, err := s.userManager.User("phil")

View File

@ -830,7 +830,7 @@ func TestServer_PublishTooRequests_Defaults_ExemptHosts(t *testing.T) {
func TestServer_PublishTooRequests_ShortReplenish(t *testing.T) {
c := newTestConfig(t)
c.VisitorRequestLimitBurst = 60
c.VisitorRequestLimitReplenish = 500 * time.Millisecond
c.VisitorRequestLimitReplenish = time.Second
s := newTestServer(t, c)
for i := 0; i < 60; i++ {
response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil)
@ -839,7 +839,7 @@ func TestServer_PublishTooRequests_ShortReplenish(t *testing.T) {
response := request(t, s, "PUT", "/mytopic", "message", nil)
require.Equal(t, 429, response.Code)
time.Sleep(520 * time.Millisecond)
time.Sleep(1020 * time.Millisecond)
response = request(t, s, "PUT", "/mytopic", "message", nil)
require.Equal(t, 200, response.Code)
}

View File

@ -241,7 +241,7 @@ func (v *visitor) Info() (*visitorInfo, error) {
var attachmentsBytesUsed int64
var err error
if v.user != nil {
attachmentsBytesUsed, err = v.messageCache.AttachmentBytesUsedByUser(v.user.Name)
attachmentsBytesUsed, err = v.messageCache.AttachmentBytesUsedByUser(v.user.ID)
} else {
attachmentsBytesUsed, err = v.messageCache.AttachmentBytesUsedBySender(v.ip.String())
}

View File

@ -16,13 +16,19 @@ import (
)
const (
bcryptCost = 10
intentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match bcryptCost
userStatsQueueWriterInterval = 33 * time.Second
tokenLength = 32
tokenExpiryDuration = 72 * time.Hour // Extend tokens by this much
syncTopicLength = 16
tokenMaxCount = 10 // Only keep this many tokens in the table per user
tierIDPrefix = "ti_"
tierIDLength = 8
syncTopicPrefix = "st_"
syncTopicLength = 16
userIDPrefix = "u_"
userIDLength = 12
userPasswordBcryptCost = 10
userAuthIntentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match userPasswordBcryptCost
userStatsQueueWriterInterval = 33 * time.Second
tokenPrefix = "tk_"
tokenLength = 32
tokenMaxCount = 10 // Only keep this many tokens in the table per user
tokenExpiryDuration = 72 * time.Hour // Extend tokens by this much
)
var (
@ -35,7 +41,7 @@ var (
const (
createTablesQueriesNoTx = `
CREATE TABLE IF NOT EXISTS tier (
id INTEGER PRIMARY KEY AUTOINCREMENT,
id TEXT PRIMARY KEY,
code TEXT NOT NULL,
name TEXT NOT NULL,
messages_limit INT NOT NULL,
@ -50,7 +56,7 @@ const (
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
CREATE TABLE IF NOT EXISTS user (
id INTEGER PRIMARY KEY AUTOINCREMENT,
id TEXT PRIMARY KEY,
tier_id INT,
user TEXT NOT NULL,
pass TEXT NOT NULL,
@ -72,7 +78,7 @@ const (
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
CREATE TABLE IF NOT EXISTS user_access (
user_id INT NOT NULL,
user_id TEXT NOT NULL,
topic TEXT NOT NULL,
read INT NOT NULL,
write INT NOT NULL,
@ -82,7 +88,7 @@ const (
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS user_token (
user_id INT NOT NULL,
user_id TEXT NOT NULL,
token TEXT NOT NULL,
expires INT NOT NULL,
PRIMARY KEY (user_id, token),
@ -93,7 +99,7 @@ const (
version INT NOT NULL
);
INSERT INTO user (id, user, pass, role, sync_topic, created_by, created_at)
VALUES (1, '*', '', 'anonymous', '', 'system', UNIXEPOCH())
VALUES ('u_everyone', '*', '', 'anonymous', '', 'system', UNIXEPOCH())
ON CONFLICT (id) DO NOTHING;
`
createTablesQueries = `BEGIN; ` + createTablesQueriesNoTx + ` COMMIT;`
@ -101,21 +107,27 @@ const (
PRAGMA foreign_keys = ON;
`
selectUserByIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE u.id = ?
`
selectUserByNameQuery = `
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE user = ?
`
selectUserByTokenQuery = `
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
FROM user u
JOIN user_token t on u.id = t.user_id
LEFT JOIN tier t on t.id = u.tier_id
WHERE t.token = ? AND t.expires >= ?
`
selectUserByStripeCustomerIDQuery = `
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE u.stripe_customer_id = ?
@ -129,8 +141,8 @@ const (
`
insertUserQuery = `
INSERT INTO user (user, pass, role, sync_topic, created_by, created_at)
VALUES (?, ?, ?, ?, ?, ?)
INSERT INTO user (id, user, pass, role, sync_topic, created_by, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?)
`
selectUsernamesQuery = `
SELECT user
@ -145,7 +157,7 @@ const (
updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?`
updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?`
updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE user = ?`
updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ? WHERE user = ?`
updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ? WHERE id = ?`
updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0`
deleteUserQuery = `DELETE FROM user WHERE user = ?`
@ -199,8 +211,8 @@ const (
AND topic = ?
`
selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE (SELECT id FROM user WHERE user = ?)`
insertTokenQuery = `INSERT INTO user_token (user_id, token, expires) VALUES ((SELECT id FROM user WHERE user = ?), ?, ?)`
selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
insertTokenQuery = `INSERT INTO user_token (user_id, token, expires) VALUES (?, ?, ?)`
updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?`
deleteTokenQuery = `DELETE FROM user_token WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?`
deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires < ?`
@ -209,27 +221,27 @@ const (
WHERE (user_id, token) NOT IN (
SELECT user_id, token
FROM user_token
WHERE user_id = (SELECT id FROM user WHERE user = ?)
WHERE user_id = ?
ORDER BY expires DESC
LIMIT ?
)
`
insertTierQuery = `
INSERT INTO tier (code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
selectTiersQuery = `
SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
FROM tier
`
selectTierByCodeQuery = `
SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
FROM tier
WHERE code = ?
`
selectTierByPriceIDQuery = `
SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
FROM tier
WHERE stripe_price_id = ?
`
@ -254,10 +266,12 @@ const (
migrate1To2RenameUserTableQueryNoTx = `
ALTER TABLE user RENAME TO user_old;
`
migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old`
migrate1To2InsertUserNoTx = `
INSERT INTO user (id, user, pass, role, sync_topic, created_by, created_at)
SELECT ?, user, pass, role, ?, 'admin', UNIXEPOCH() FROM user_old WHERE user = ?
`
migrate1To2InsertFromOldTablesAndDropNoTx = `
INSERT INTO user (user, pass, role, sync_topic, created_by, created_at)
SELECT user, pass, role, '', 'admin', UNIXEPOCH() FROM user_old;
INSERT INTO user_access (user_id, topic, read, write)
SELECT u.id, a.topic, a.read, a.write
FROM user u
@ -266,8 +280,7 @@ const (
DROP TABLE access;
DROP TABLE user_old;
`
migrate1To2SelectAllUsersIDsNoTx = `SELECT id FROM user`
migrate1To2UpdateSyncTopicNoTx = `UPDATE user SET sync_topic = ? WHERE id = ?`
migrate1To2UpdateSyncTopicNoTx = `UPDATE user SET sync_topic = ? WHERE id = ?`
)
// Manager is an implementation of Manager. It stores users and access control list
@ -317,7 +330,7 @@ func (a *Manager) Authenticate(username, password string) (*User, error) {
user, err := a.User(username)
if err != nil {
log.Trace("authentication of user %s failed (1): %s", username, err.Error())
bcrypt.CompareHashAndPassword([]byte(intentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
return nil, ErrUnauthenticated
}
if err := bcrypt.CompareHashAndPassword([]byte(user.Hash), []byte(password)); err != nil {
@ -345,16 +358,16 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) {
// after a fixed duration unless ExtendToken is called. This function also prunes tokens for the
// given user, if there are too many of them.
func (a *Manager) CreateToken(user *User) (*Token, error) {
token, expires := util.RandomString(tokenLength), time.Now().Add(tokenExpiryDuration)
token, expires := util.RandomStringPrefix(tokenPrefix, tokenLength), time.Now().Add(tokenExpiryDuration)
tx, err := a.db.Begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
if _, err := tx.Exec(insertTokenQuery, user.Name, token, expires.Unix()); err != nil {
if _, err := tx.Exec(insertTokenQuery, user.ID, token, expires.Unix()); err != nil {
return nil, err
}
rows, err := tx.Query(selectTokenCountQuery, user.Name)
rows, err := tx.Query(selectTokenCountQuery, user.ID)
if err != nil {
return nil, err
}
@ -369,7 +382,7 @@ func (a *Manager) CreateToken(user *User) (*Token, error) {
if tokenCount >= tokenMaxCount {
// This pruning logic is done in two queries for efficiency. The SELECT above is a lookup
// on two indices, whereas the query below is a full table scan.
if _, err := tx.Exec(deleteExcessTokensQuery, user.Name, tokenMaxCount); err != nil {
if _, err := tx.Exec(deleteExcessTokensQuery, user.ID, tokenMaxCount); err != nil {
return nil, err
}
}
@ -444,7 +457,7 @@ func (a *Manager) ResetStats() error {
func (a *Manager) EnqueueStats(user *User) {
a.mu.Lock()
defer a.mu.Unlock()
a.statsQueue[user.Name] = user
a.statsQueue[user.ID] = user
}
func (a *Manager) userStatsQueueWriter(interval time.Duration) {
@ -472,9 +485,9 @@ func (a *Manager) writeUserStatsQueue() error {
}
defer tx.Rollback()
log.Debug("User Manager: Writing user stats queue for %d user(s)", len(statsQueue))
for username, u := range statsQueue {
log.Trace("User Manager: Updating stats for user %s: messages=%d, emails=%d", username, u.Stats.Messages, u.Stats.Emails)
if _, err := tx.Exec(updateUserStatsQuery, u.Stats.Messages, u.Stats.Emails, username); err != nil {
for userID, u := range statsQueue {
log.Trace("User Manager: Updating stats for user %s: messages=%d, emails=%d", userID, u.Stats.Messages, u.Stats.Emails)
if _, err := tx.Exec(updateUserStatsQuery, u.Stats.Messages, u.Stats.Emails, userID); err != nil {
return err
}
}
@ -524,12 +537,13 @@ func (a *Manager) AddUser(username, password string, role Role, createdBy string
if !AllowedUsername(username) || !AllowedRole(role) {
return ErrInvalidArgument
}
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcryptCost)
hash, err := bcrypt.GenerateFromPassword([]byte(password), userPasswordBcryptCost)
if err != nil {
return err
}
syncTopic, now := util.RandomString(syncTopicLength), time.Now().Unix()
if _, err = a.db.Exec(insertUserQuery, username, hash, role, syncTopic, createdBy, now); err != nil {
userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
syncTopic, now := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength), time.Now().Unix()
if _, err = a.db.Exec(insertUserQuery, userID, username, hash, role, syncTopic, createdBy, now); err != nil {
return err
}
return nil
@ -587,6 +601,15 @@ func (a *Manager) User(username string) (*User, error) {
return a.readUser(rows)
}
// UserByID returns the user with the given ID if it exists, or ErrUserNotFound otherwise
func (a *Manager) UserByID(id string) (*User, error) {
rows, err := a.db.Query(selectUserByIDQuery, id)
if err != nil {
return nil, err
}
return a.readUser(rows)
}
// UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise.
func (a *Manager) UserByStripeCustomer(stripeCustomerID string) (*User, error) {
rows, err := a.db.Query(selectUserByStripeCustomerIDQuery, stripeCustomerID)
@ -606,19 +629,20 @@ func (a *Manager) userByToken(token string) (*User, error) {
func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
defer rows.Close()
var username, hash, role, prefs, syncTopic string
var id, username, hash, role, prefs, syncTopic string
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString
var messages, emails int64
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt sql.NullInt64
if !rows.Next() {
return nil, ErrUserNotFound
}
if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
}
user := &User{
ID: id,
Name: username,
Hash: hash,
Role: Role(role),
@ -744,7 +768,7 @@ func (a *Manager) ReservationsCount(username string) (int64, error) {
// ChangePassword changes a user's password
func (a *Manager) ChangePassword(username, password string) error {
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcryptCost)
hash, err := bcrypt.GenerateFromPassword([]byte(password), userPasswordBcryptCost)
if err != nil {
return err
}
@ -818,6 +842,7 @@ func (a *Manager) checkReservationsLimit(username string, reservationsLimit int6
// CheckAllowAccess tests if a user may create an access control entry for the given topic.
// If there are any ACL entries that are not owned by the user, an error is returned.
// FIXME is this the same as HasReservation?
func (a *Manager) CheckAllowAccess(username string, topic string) error {
if (!AllowedUsername(username) && username != Everyone) || !AllowedTopic(topic) {
return ErrInvalidArgument
@ -856,24 +881,6 @@ func (a *Manager) AllowAccess(username string, topicPattern string, permission P
return nil
}
func (a *Manager) ReserveAccess(username string, topic string, everyone Permission) error {
if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) {
return ErrInvalidArgument
}
tx, err := a.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(upsertUserAccessQuery, username, topic, true, true, username, username); err != nil {
return err
}
if _, err := tx.Exec(upsertUserAccessQuery, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username, username); err != nil {
return err
}
return tx.Commit()
}
// ResetAccess removes an access control list entry for a specific username/topic, or (if topic is
// empty) for an entire user. The parameter topicPattern may include wildcards (*).
func (a *Manager) ResetAccess(username string, topicPattern string) error {
@ -893,6 +900,29 @@ func (a *Manager) ResetAccess(username string, topicPattern string) error {
return err
}
// AddReservation creates two access control entries for the given topic: one with full read/write access for the
// given user, and one for Everyone with the permission passed as everyone. The user also owns the entries, and
// can modify or delete them.
func (a *Manager) AddReservation(username string, topic string, everyone Permission) error {
if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) {
return ErrInvalidArgument
}
tx, err := a.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(upsertUserAccessQuery, username, topic, true, true, username, username); err != nil {
return err
}
if _, err := tx.Exec(upsertUserAccessQuery, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username, username); err != nil {
return err
}
return tx.Commit()
}
// RemoveReservations deletes the access control entries associated with the given username/topic, as
// well as all entries with Everyone/topic. This is the counterpart for AddReservation.
func (a *Manager) RemoveReservations(username string, topics ...string) error {
if !AllowedUsername(username) || username == Everyone || len(topics) == 0 {
return ErrInvalidArgument
@ -925,7 +955,8 @@ func (a *Manager) DefaultAccess() Permission {
// CreateTier creates a new tier in the database
func (a *Manager) CreateTier(tier *Tier) error {
if _, err := a.db.Exec(insertTierQuery, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.StripePriceID); err != nil {
tierID := util.RandomStringPrefix(tierIDPrefix, tierIDLength)
if _, err := a.db.Exec(insertTierQuery, tierID, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.StripePriceID); err != nil {
return err
}
return nil
@ -980,19 +1011,20 @@ func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
}
func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
var code, name string
var id, code, name string
var stripePriceID sql.NullString
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64
if !rows.Next() {
return nil, ErrTierNotFound
}
if err := rows.Scan(&code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
}
// When changed, note readUser() as well
return &Tier{
ID: id,
Code: code,
Name: name,
Paid: stripePriceID.Valid, // If there is a price, it's a paid tier
@ -1069,36 +1101,41 @@ func migrateFrom1(db *sql.DB) error {
return err
}
defer tx.Rollback()
// Rename user -> user_old, and create new tables
if _, err := tx.Exec(migrate1To2RenameUserTableQueryNoTx); err != nil {
return err
}
if _, err := tx.Exec(createTablesQueriesNoTx); err != nil {
return err
}
if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil {
return err
}
rows, err := tx.Query(migrate1To2SelectAllUsersIDsNoTx)
// Insert users from user_old into new user table, with ID and sync_topic
rows, err := tx.Query(migrate1To2SelectAllOldUsernamesNoTx)
if err != nil {
return err
}
defer rows.Close()
syncTopics := make(map[int]string)
usernames := make([]string, 0)
for rows.Next() {
var userID int
if err := rows.Scan(&userID); err != nil {
var username string
if err := rows.Scan(&username); err != nil {
return err
}
syncTopics[userID] = util.RandomString(syncTopicLength)
usernames = append(usernames, username)
}
if err := rows.Close(); err != nil {
return err
}
for userID, syncTopic := range syncTopics {
if _, err := tx.Exec(migrate1To2UpdateSyncTopicNoTx, syncTopic, userID); err != nil {
for _, username := range usernames {
userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
if _, err := tx.Exec(migrate1To2InsertUserNoTx, userID, syncTopic, username); err != nil {
return err
}
}
// Migrate old "access" table to "user_access" and drop "access" and "user_old"
if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 2); err != nil {
return err
}

View File

@ -259,8 +259,8 @@ func TestManager_ChangeRole(t *testing.T) {
func TestManager_Reservations(t *testing.T) {
a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
require.Nil(t, a.ReserveAccess("ben", "ztopic", PermissionDenyAll))
require.Nil(t, a.ReserveAccess("ben", "readme", PermissionRead))
require.Nil(t, a.AddReservation("ben", "ztopic", PermissionDenyAll))
require.Nil(t, a.AddReservation("ben", "readme", PermissionRead))
require.Nil(t, a.AllowAccess("ben", "something-else", PermissionRead))
reservations, err := a.Reservations("ben")
@ -294,7 +294,7 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
}))
require.Nil(t, a.AddUser("ben", "ben", RoleUser, "unit-test"))
require.Nil(t, a.ChangeTier("ben", "pro"))
require.Nil(t, a.ReserveAccess("ben", "mytopic", PermissionDenyAll))
require.Nil(t, a.AddReservation("ben", "mytopic", PermissionDenyAll))
ben, err := a.User("ben")
require.Nil(t, err)
@ -626,11 +626,14 @@ func TestSqliteCache_Migration_From1(t *testing.T) {
everyoneGrants, err := a.Grants(Everyone)
require.Nil(t, err)
require.True(t, strings.HasPrefix(phil.ID, "u_"))
require.Equal(t, "phil", phil.Name)
require.Equal(t, RoleAdmin, phil.Role)
require.Equal(t, syncTopicLength, len(phil.SyncTopic))
require.Equal(t, 0, len(philGrants))
require.True(t, strings.HasPrefix(ben.ID, "u_"))
require.NotEqual(t, phil.ID, ben.ID)
require.Equal(t, "ben", ben.Name)
require.Equal(t, RoleUser, ben.Role)
require.Equal(t, syncTopicLength, len(ben.SyncTopic))
@ -641,6 +644,7 @@ func TestSqliteCache_Migration_From1(t *testing.T) {
require.Equal(t, "secret", benGrants[1].TopicPattern)
require.Equal(t, PermissionRead, benGrants[1].Allow)
require.Equal(t, "u_everyone", everyone.ID)
require.Equal(t, Everyone, everyone.Name)
require.Equal(t, RoleAnonymous, everyone.Role)
require.Equal(t, 1, len(everyoneGrants))

View File

@ -10,6 +10,7 @@ import (
// User is a struct that represents a user
type User struct {
ID string
Name string
Hash string // password hash (bcrypt)
Token string // Only set if token was used to log in
@ -50,6 +51,7 @@ type Prefs struct {
// Tier represents a user's account type, including its account limits
type Tier struct {
ID string
Code string
Name string
Paid bool

View File

@ -107,13 +107,18 @@ func LastString(s []string, def string) string {
// RandomString returns a random string with a given length
func RandomString(length int) string {
return RandomStringPrefix("", length)
}
// RandomStringPrefix returns a random string with a given length, with a prefix
func RandomStringPrefix(prefix string, length int) string {
randomMutex.Lock() // Who would have thought that random.Intn() is not thread-safe?!
defer randomMutex.Unlock()
b := make([]byte, length)
b := make([]byte, length-len(prefix))
for i := range b {
b[i] = randomStringCharset[random.Intn(len(randomStringCharset))]
}
return string(b)
return prefix + string(b)
}
// ValidRandomString returns true if the given string matches the format created by RandomString