1
0
Fork 0
mirror of https://github.com/binwiederhier/ntfy.git synced 2024-11-26 21:25:18 +01:00

Prune excess tokens per user

This commit is contained in:
binwiederhier 2023-01-05 20:22:34 -05:00
parent 60f1882bec
commit 7fa63c8e19
3 changed files with 108 additions and 29 deletions

View file

@ -45,10 +45,6 @@ import (
"account topic" sync mechanism "account topic" sync mechanism
purge accounts that were not logged into in X purge accounts that were not logged into in X
reset daily limits for users reset daily limits for users
max token issue limit
user db startup queries -> foreign keys
UI
- Feature flag for "reserve topic" feature
Sync: Sync:
- "mute" setting - "mute" setting
- figure out what settings are "web" or "phone" - figure out what settings are "web" or "phone"

View file

@ -15,16 +15,18 @@ import (
) )
const ( const (
tokenLength = 32
bcryptCost = 10 bcryptCost = 10
intentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match bcryptCost intentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match bcryptCost
userStatsQueueWriterInterval = 33 * time.Second userStatsQueueWriterInterval = 33 * time.Second
userTokenExpiryDuration = 72 * time.Hour tokenLength = 32
tokenExpiryDuration = 72 * time.Hour // Extend tokens by this much
tokenMaxCount = 10 // Only keep this many tokens in the table per user
) )
var ( var (
errNoTokenProvided = errors.New("no token provided") errNoTokenProvided = errors.New("no token provided")
errTopicOwnedByOthers = errors.New("topic owned by others") errTopicOwnedByOthers = errors.New("topic owned by others")
errNoRows = errors.New("no rows found")
) )
// Manager-related queries // Manager-related queries
@ -139,7 +141,7 @@ const (
ORDER BY a_user.topic ORDER BY a_user.topic
` `
selectOtherAccessCountQuery = ` selectOtherAccessCountQuery = `
SELECT count(*) SELECT COUNT(*)
FROM user_access FROM user_access
WHERE (topic = ? OR ? LIKE topic) WHERE (topic = ? OR ? LIKE topic)
AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?)) AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?))
@ -148,10 +150,22 @@ const (
deleteUserAccessQuery = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?)` deleteUserAccessQuery = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?)`
deleteTopicAccessQuery = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?) AND topic = ?` deleteTopicAccessQuery = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?) AND topic = ?`
selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE (SELECT id FROM user WHERE user = ?)`
insertTokenQuery = `INSERT INTO user_token (user_id, token, expires) VALUES ((SELECT id FROM user WHERE user = ?), ?, ?)` insertTokenQuery = `INSERT INTO user_token (user_id, token, expires) VALUES ((SELECT id FROM user WHERE user = ?), ?, ?)`
updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?` updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?`
deleteTokenQuery = `DELETE FROM user_token WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?` deleteTokenQuery = `DELETE FROM user_token WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?`
deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires < ?` deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires < ?`
deleteExcessTokensQuery = `
DELETE FROM user_token
WHERE (user_id, token) NOT IN (
SELECT user_id, token
FROM user_token
WHERE user_id = (SELECT id FROM user WHERE user = ?)
ORDER BY expires DESC
LIMIT ?
)
;
`
) )
// Schema management queries // Schema management queries
@ -185,7 +199,6 @@ type Manager struct {
db *sql.DB db *sql.DB
defaultAccess Permission // Default permission if no ACL matches defaultAccess Permission // Default permission if no ACL matches
statsQueue map[string]*User // Username -> User, for "unimportant" user updates statsQueue map[string]*User // Username -> User, for "unimportant" user updates
tokenExpiryInterval time.Duration // Duration after which tokens expire, and by which tokens are extended
mu sync.Mutex mu sync.Mutex
} }
@ -193,11 +206,11 @@ var _ Auther = (*Manager)(nil)
// NewManager creates a new Manager instance // NewManager creates a new Manager instance
func NewManager(filename, startupQueries string, defaultAccess Permission) (*Manager, error) { func NewManager(filename, startupQueries string, defaultAccess Permission) (*Manager, error) {
return newManager(filename, startupQueries, defaultAccess, userTokenExpiryDuration, userStatsQueueWriterInterval) return newManager(filename, startupQueries, defaultAccess, userStatsQueueWriterInterval)
} }
// NewManager creates a new Manager instance // NewManager creates a new Manager instance
func newManager(filename, startupQueries string, defaultAccess Permission, tokenExpiryDuration, statsWriterInterval time.Duration) (*Manager, error) { func newManager(filename, startupQueries string, defaultAccess Permission, statsWriterInterval time.Duration) (*Manager, error) {
db, err := sql.Open("sqlite3", filename) db, err := sql.Open("sqlite3", filename)
if err != nil { if err != nil {
return nil, err return nil, err
@ -212,7 +225,6 @@ func newManager(filename, startupQueries string, defaultAccess Permission, token
db: db, db: db,
defaultAccess: defaultAccess, defaultAccess: defaultAccess,
statsQueue: make(map[string]*User), statsQueue: make(map[string]*User),
tokenExpiryInterval: tokenExpiryDuration,
} }
go manager.userStatsQueueWriter(statsWriterInterval) go manager.userStatsQueueWriter(statsWriterInterval)
return manager, nil return manager, nil
@ -253,10 +265,38 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) {
} }
// CreateToken generates a random token for the given user and returns it. The token expires // CreateToken generates a random token for the given user and returns it. The token expires
// after a fixed duration unless ExtendToken is called. // after a fixed duration unless ExtendToken is called. This function also prunes tokens for the
// given user, if there are too many of them.
func (a *Manager) CreateToken(user *User) (*Token, error) { func (a *Manager) CreateToken(user *User) (*Token, error) {
token, expires := util.RandomString(tokenLength), time.Now().Add(userTokenExpiryDuration) token, expires := util.RandomString(tokenLength), time.Now().Add(tokenExpiryDuration)
if _, err := a.db.Exec(insertTokenQuery, user.Name, token, expires.Unix()); err != nil { tx, err := a.db.Begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
if _, err := tx.Exec(insertTokenQuery, user.Name, token, expires.Unix()); err != nil {
return nil, err
}
rows, err := tx.Query(selectTokenCountQuery, user.Name)
if err != nil {
return nil, err
}
defer rows.Close()
if !rows.Next() {
return nil, errNoRows
}
var tokenCount int
if err := rows.Scan(&tokenCount); err != nil {
return nil, err
}
if tokenCount >= tokenMaxCount {
// This pruning logic is done in two queries for efficiency. The SELECT above is a lookup
// on two indices, whereas the query below is a full table scan.
if _, err := tx.Exec(deleteExcessTokensQuery, user.Name, tokenMaxCount); err != nil {
return nil, err
}
}
if err := tx.Commit(); err != nil {
return nil, err return nil, err
} }
return &Token{ return &Token{
@ -270,7 +310,7 @@ func (a *Manager) ExtendToken(user *User) (*Token, error) {
if user.Token == "" { if user.Token == "" {
return nil, errNoTokenProvided return nil, errNoTokenProvided
} }
newExpires := time.Now().Add(userTokenExpiryDuration) newExpires := time.Now().Add(tokenExpiryDuration)
if _, err := a.db.Exec(updateTokenExpiryQuery, newExpires.Unix(), user.Name, user.Token); err != nil { if _, err := a.db.Exec(updateTokenExpiryQuery, newExpires.Unix(), user.Name, user.Token); err != nil {
return nil, err return nil, err
} }
@ -600,7 +640,7 @@ func (a *Manager) CheckAllowAccess(username string, topic string) error {
} }
defer rows.Close() defer rows.Close()
if !rows.Next() { if !rows.Next() {
return errors.New("no rows found") return errNoRows
} }
var otherCount int var otherCount int
if err := rows.Scan(&otherCount); err != nil { if err := rows.Scan(&otherCount); err != nil {

View file

@ -369,8 +369,51 @@ func TestManager_Token_Extend(t *testing.T) {
require.True(t, token.Expires.Unix() < extendedToken.Expires.Unix()) require.True(t, token.Expires.Unix() < extendedToken.Expires.Unix())
} }
func TestManager_Token_MaxCount_AutoDelete(t *testing.T) {
a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
// Try to extend token for user without token
u, err := a.User("ben")
require.Nil(t, err)
// Tokens
baseTime := time.Now().Add(24 * time.Hour)
tokens := make([]string, 0)
for i := 0; i < 12; i++ {
token, err := a.CreateToken(u)
require.Nil(t, err)
require.NotEmpty(t, token.Value)
tokens = append(tokens, token.Value)
// Manually modify expiry date to avoid sorting issues (this is a hack)
_, err = a.db.Exec(`UPDATE user_token SET expires=? WHERE token=?`, baseTime.Add(time.Duration(i)*time.Minute).Unix(), token.Value)
require.Nil(t, err)
}
_, err = a.AuthenticateToken(tokens[0])
require.Equal(t, ErrUnauthenticated, err)
_, err = a.AuthenticateToken(tokens[1])
require.Equal(t, ErrUnauthenticated, err)
for i := 2; i < 12; i++ {
userWithToken, err := a.AuthenticateToken(tokens[i])
require.Nil(t, err, "token[%d]=%s failed", i, tokens[i])
require.Equal(t, "ben", userWithToken.Name)
require.Equal(t, tokens[i], userWithToken.Token)
}
var count int
rows, err := a.db.Query(`SELECT COUNT(*) FROM user_token`)
require.Nil(t, err)
require.True(t, rows.Next())
require.Nil(t, rows.Scan(&count))
require.Equal(t, 10, count)
}
func TestManager_EnqueueStats(t *testing.T) { func TestManager_EnqueueStats(t *testing.T) {
a, err := newManager(filepath.Join(t.TempDir(), "db"), PermissionReadWrite, time.Hour, 1500*time.Millisecond) a, err := newManager(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, 1500*time.Millisecond)
require.Nil(t, err) require.Nil(t, err)
require.Nil(t, a.AddUser("ben", "ben", RoleUser)) require.Nil(t, a.AddUser("ben", "ben", RoleUser))
@ -400,7 +443,7 @@ func TestManager_EnqueueStats(t *testing.T) {
} }
func TestManager_ChangeSettings(t *testing.T) { func TestManager_ChangeSettings(t *testing.T) {
a, err := newManager(filepath.Join(t.TempDir(), "db"), PermissionReadWrite, time.Hour, 1500*time.Millisecond) a, err := newManager(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, 1500*time.Millisecond)
require.Nil(t, err) require.Nil(t, err)
require.Nil(t, a.AddUser("ben", "ben", RoleUser)) require.Nil(t, a.AddUser("ben", "ben", RoleUser))
@ -482,7 +525,7 @@ func TestSqliteCache_Migration_From1(t *testing.T) {
require.Nil(t, err) require.Nil(t, err)
// Create manager to trigger migration // Create manager to trigger migration
a := newTestManagerFromFile(t, filename, PermissionDenyAll, userTokenExpiryDuration, userStatsQueueWriterInterval) a := newTestManagerFromFile(t, filename, "", PermissionDenyAll, userStatsQueueWriterInterval)
checkSchemaVersion(t, a.db) checkSchemaVersion(t, a.db)
users, err := a.Users() users, err := a.Users()
@ -530,11 +573,11 @@ func checkSchemaVersion(t *testing.T, db *sql.DB) {
} }
func newTestManager(t *testing.T, defaultAccess Permission) *Manager { func newTestManager(t *testing.T, defaultAccess Permission) *Manager {
return newTestManagerFromFile(t, filepath.Join(t.TempDir(), "user.db"), defaultAccess, userTokenExpiryDuration, userStatsQueueWriterInterval) return newTestManagerFromFile(t, filepath.Join(t.TempDir(), "user.db"), "", defaultAccess, userStatsQueueWriterInterval)
} }
func newTestManagerFromFile(t *testing.T, filename string, defaultAccess Permission, tokenExpiryDuration, statsWriterInterval time.Duration) *Manager { func newTestManagerFromFile(t *testing.T, filename, startupQueries string, defaultAccess Permission, statsWriterInterval time.Duration) *Manager {
a, err := newManager(filename, defaultAccess, tokenExpiryDuration, statsWriterInterval) a, err := newManager(filename, startupQueries, defaultAccess, statsWriterInterval)
require.Nil(t, err) require.Nil(t, err)
return a return a
} }